Repository: ami-iit/jaxsim
Branch: main
Commit: 62ac1bc2707f
Files: 136
Total size: 988.2 KB
Directory structure:
gitextract_5j4outfv/
├── .devcontainer/
│ ├── Dockerfile
│ └── devcontainer.json
├── .gitattributes
├── .github/
│ ├── CODEOWNERS
│ ├── dependabot.yml
│ ├── release.yml
│ └── workflows/
│ ├── ci_cd.yml
│ ├── gpu_benchmark.yml
│ ├── pixi.yml
│ └── read_the_docs.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CITATION.cff
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs/
│ ├── Makefile
│ ├── conf.py
│ ├── examples.rst
│ ├── guide/
│ │ ├── configuration.rst
│ │ └── install.rst
│ ├── index.rst
│ ├── make.bat
│ └── modules/
│ ├── api.rst
│ ├── math.rst
│ ├── mujoco.rst
│ ├── parsers.rst
│ ├── rbda.rst
│ ├── typing.rst
│ └── utils.rst
├── environment.yml
├── examples/
│ ├── .gitattributes
│ ├── .gitignore
│ ├── README.md
│ ├── assets/
│ │ ├── build_cartpole_urdf.py
│ │ └── cartpole.urdf
│ ├── jaxsim_as_multibody_dynamics_library.ipynb
│ ├── jaxsim_as_physics_engine.ipynb
│ ├── jaxsim_as_physics_engine_advanced.ipynb
│ └── jaxsim_for_robot_controllers.ipynb
├── pyproject.toml
├── src/
│ └── jaxsim/
│ ├── __init__.py
│ ├── api/
│ │ ├── __init__.py
│ │ ├── actuation_model.py
│ │ ├── com.py
│ │ ├── common.py
│ │ ├── contact.py
│ │ ├── data.py
│ │ ├── frame.py
│ │ ├── integrators.py
│ │ ├── joint.py
│ │ ├── kin_dyn_parameters.py
│ │ ├── link.py
│ │ ├── model.py
│ │ ├── ode.py
│ │ └── references.py
│ ├── exceptions.py
│ ├── logging.py
│ ├── math/
│ │ ├── __init__.py
│ │ ├── adjoint.py
│ │ ├── cross.py
│ │ ├── inertia.py
│ │ ├── joint_model.py
│ │ ├── quaternion.py
│ │ ├── rotation.py
│ │ ├── skew.py
│ │ ├── transform.py
│ │ └── utils.py
│ ├── mujoco/
│ │ ├── __init__.py
│ │ ├── __main__.py
│ │ ├── loaders.py
│ │ ├── model.py
│ │ ├── utils.py
│ │ └── visualizer.py
│ ├── parsers/
│ │ ├── __init__.py
│ │ ├── descriptions/
│ │ │ ├── __init__.py
│ │ │ ├── collision.py
│ │ │ ├── joint.py
│ │ │ ├── link.py
│ │ │ └── model.py
│ │ ├── kinematic_graph.py
│ │ └── rod/
│ │ ├── __init__.py
│ │ ├── meshes.py
│ │ ├── parser.py
│ │ └── utils.py
│ ├── rbda/
│ │ ├── __init__.py
│ │ ├── aba.py
│ │ ├── aba_parallel.py
│ │ ├── actuation/
│ │ │ ├── __init__.py
│ │ │ └── common.py
│ │ ├── collidable_points.py
│ │ ├── contacts/
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── relaxed_rigid.py
│ │ │ ├── rigid.py
│ │ │ └── soft.py
│ │ ├── crba.py
│ │ ├── forward_kinematics.py
│ │ ├── forward_kinematics_parallel.py
│ │ ├── jacobian.py
│ │ ├── kinematic_constraints.py
│ │ ├── mass_inverse.py
│ │ ├── rnea.py
│ │ └── utils.py
│ ├── terrain/
│ │ ├── __init__.py
│ │ └── terrain.py
│ ├── typing.py
│ └── utils/
│ ├── __init__.py
│ ├── jaxsim_dataclass.py
│ ├── tracing.py
│ └── wrappers.py
└── tests/
├── __init__.py
├── assets/
│ ├── 4_bar_opened.urdf
│ ├── cube.stl
│ ├── double_pendulum.sdf
│ ├── mixed_shapes_robot.urdf
│ └── test_cube.urdf
├── conftest.py
├── test_actuation.py
├── test_api_com.py
├── test_api_contact.py
├── test_api_data.py
├── test_api_frame.py
├── test_api_joint.py
├── test_api_link.py
├── test_api_model.py
├── test_api_model_hw_parametrization.py
├── test_automatic_differentiation.py
├── test_benchmark.py
├── test_exceptions.py
├── test_meshes.py
├── test_pytree.py
├── test_simulations.py
├── test_visualizer.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .devcontainer/Dockerfile
================================================
# syntax=docker/dockerfile:1.4
# Development container image: Ubuntu (jammy) devcontainer base + pixi + git-lfs.
FROM mcr.microsoft.com/devcontainers/base:jammy

ARG PROJECT_NAME=jaxsim
# The README states that pixi >= 0.39.0 is required; the previous pin (v0.35.0)
# was below that minimum. Aligned with the pixi version used by the GPU
# benchmark workflow image (0.46.0). Override with `--build-arg PIXI_VERSION=...`.
ARG PIXI_VERSION=v0.46.0

# Make piped commands (e.g. `curl ... | bash`) fail when any stage fails,
# instead of only reporting the exit status of the last command.
SHELL ["/bin/bash", "-o", "pipefail", "-c"]

# Download the pixi binary matching the build architecture.
# `-f` makes curl exit non-zero on HTTP errors rather than writing the error
# page to /usr/local/bin/pixi.
RUN curl -fSL -o /usr/local/bin/pixi "https://github.com/prefix-dev/pixi/releases/download/${PIXI_VERSION}/pixi-$(uname -m)-unknown-linux-musl" \
    && chmod +x /usr/local/bin/pixi \
    && pixi info

# Add LFS repository and install.
# `apt-get` is used throughout (stable CLI for scripts) and the package lists
# are removed in the same layer so they do not persist in the image.
RUN apt-get update \
    && apt-get install -y --no-install-recommends curl \
    && curl -fsSL https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash \
    && apt-get install -y --no-install-recommends git-lfs \
    && rm -rf /var/lib/apt/lists/*

# Drop privileges: the remaining setup and the running container use `vscode`.
USER vscode
WORKDIR /home/vscode

# Enable pixi shell completion for interactive bash sessions.
RUN echo 'eval "$(pixi completion -s bash)"' >> /home/vscode/.bashrc
================================================
FILE: .devcontainer/devcontainer.json
================================================
// Dev container definition for working on the project inside VS Code / Codespaces.
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/ubuntu
{
"name": "Ubuntu",
// Build the image from the Dockerfile next to this file, using the repository root as build context.
// More info: https://containers.dev/guide/dockerfile
"build": {
"context": "..",
"dockerfile": "Dockerfile"
},
// Features to add to the dev container. More info: https://containers.dev/features.
"features": {
"ghcr.io/devcontainers/features/docker-in-docker:2": {}
},
// Keep the `.pixi` folder in its own named volume rather than in the workspace mount.
// NOTE(review): per the pixi dev-container guidance this is done so that `.pixi` does not
// live on a case-insensitive host filesystem (e.g. macOS/Windows defaults) — confirm.
"mounts": ["source=${localWorkspaceFolderBasename}-pixi,target=${containerWorkspaceFolder}/.pixi,type=volume"],
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Use 'postCreateCommand' to run commands after the container is created.
// Here: hand the `.pixi` volume to the `vscode` user, fetch `pixi.lock` from Git LFS,
// then install the `test-cpu` pixi environment.
"postCreateCommand": "sudo chown vscode .pixi && git lfs pull --include='pixi.lock' && pixi install --environment=test-cpu",
// Configure tool-specific properties.
// "customizations": {},
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
// VS Code settings and extensions.
"customizations": {
"vscode": {
"settings": {
// Both interpreter paths hard-code `/workspaces/jaxsim`, i.e. they assume the
// repository folder is named `jaxsim`.
// NOTE(review): `python.pythonPath` looks like the legacy counterpart of
// `python.defaultInterpreterPath` — confirm whether it is still needed.
"python.pythonPath": "/workspaces/jaxsim/.pixi/envs/test-cpu/bin/python",
"python.defaultInterpreterPath": "/workspaces/jaxsim/.pixi/envs/test-cpu/bin/python",
"python.terminal.activateEnvironment": true,
"python.terminal.activateEnvInCurrentTerminal": true
},
"extensions": [
"ms-python.python",
"donjayamanne.python-extension-pack",
"ms-toolsai.jupyter",
"GitHub.codespaces",
"GitHub.copilot",
"ms-azuretools.vscode-docker",
"charliermarsh.ruff"
]
}
}
}
================================================
FILE: .gitattributes
================================================
# GitHub syntax highlighting
pixi.lock filter=lfs diff=lfs merge=lfs -text
================================================
FILE: .github/CODEOWNERS
================================================
* @flferretti
================================================
FILE: .github/dependabot.yml
================================================
version: 2
updates:
# Check for updates to GitHub Actions every month.
- package-ecosystem: github-actions
directory: /
schedule:
interval: monthly
# Disable rebasing automatically existing pull requests.
rebase-strategy: "disabled"
# Group updates to a single PR.
groups:
dependencies:
patterns:
- '*'
================================================
FILE: .github/release.yml
================================================
changelog:
exclude:
authors:
- dependabot[bot]
- pre-commit-ci[bot]
- github-actions[bot]
================================================
FILE: .github/workflows/ci_cd.yml
================================================
name: Python CI/CD

# Pipeline: build the sdist/wheel once (`package`), install and test the wheel
# on every OS/Python combination (`test`), then publish to PyPI (`publish`).

on:
  workflow_dispatch:
  push:
  pull_request:
  release:
    types:
      - published
  schedule:
    # Execute a nightly build at 2am UTC.
    - cron: '0 2 * * *'

jobs:

  package:
    name: Package the project
    runs-on: ubuntu-latest

    steps:

      # Full history is fetched (fetch-depth: 0).
      # NOTE(review): presumably required to derive the dynamic package version
      # from git metadata — confirm against pyproject.toml.
      - uses: actions/checkout@v6
        with:
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v6
        with:
          python-version: "3.11"

      - name: Install Python tools
        run: pip install build twine

      - name: Create distributions
        run: python -m build -o dist/

      - name: Inspect dist folder
        run: ls -lah dist/

      # The pattern is quoted so that `find` receives the glob itself; unquoted,
      # the shell could expand it against the working directory first.
      - name: Check wheel's abi and platform tags
        run: test "$(find dist/ -name '*-none-any.whl' | wc -l)" -gt 0

      - name: Run twine check
        run: twine check dist/*

      - name: Upload artifacts
        uses: actions/upload-artifact@v7
        with:
          path: dist/*
          name: dist

  test:
    name: 'Python${{ matrix.python }}@${{ matrix.os }}'
    needs: package
    runs-on: ${{ matrix.os }}
    env:
      PYTHONUTF8: "1"
    strategy:
      fail-fast: false
      matrix:
        os:
          - ubuntu-latest
          - macos-latest
          - windows-latest
        python:
          - "3.10"
          - "3.11"
          - "3.12"
          - "3.13"

    steps:

      - name: Set up Python
        uses: actions/setup-python@v6
        with:
          python-version: ${{ matrix.python }}

      - name: Download Python packages
        uses: actions/download-artifact@v8
        with:
          path: dist
          name: dist

      # Single step for all operating systems: the former per-OS steps ran the
      # exact same command under complementary conditions covering the whole matrix.
      - name: Install wheel
        shell: bash
        run: pip install "$(find dist/ -type f -name '*.whl')"

      - name: Document installed pip packages
        shell: bash
        run: pip list --verbose

      - name: Import the package
        run: python -c "import jaxsim"

      # The repository is checked out only after the wheel smoke test above,
      # so the import check runs against the installed wheel alone.
      - uses: actions/checkout@v6
        with:
          lfs: true

      - uses: prefix-dev/setup-pixi@v0.9.5
        if: contains(matrix.os, 'ubuntu')
        with:
          pixi-version: "latest"
          frozen: true
          cache: true
          # Only pushes to main are allowed to write the pixi cache.
          cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}

      # The test suite runs on Ubuntu only, and is skipped for `pull_request` events.
      - name: Run the Python tests
        if: |
          contains(matrix.os, 'ubuntu') &&
          (github.event_name != 'pull_request')
        run: pixi run --frozen test --numprocesses auto
        env:
          # https://github.com/pytest-dev/pytest/issues/7443#issuecomment-656642591
          PY_COLORS: "1"
          JAX_PLATFORM_NAME: cpu

  publish:
    name: Publish to PyPI
    needs: test
    runs-on: ubuntu-latest
    permissions:
      id-token: write

    steps:

      - name: Download Python packages
        uses: actions/download-artifact@v8
        with:
          path: dist
          name: dist

      - name: Inspect dist folder
        run: ls -lah dist/

      # Publish only from the upstream repository, on pushes to main or on releases.
      - name: Publish to PyPI
        if: |
          github.repository == 'gbionics/jaxsim' &&
          ((github.event_name == 'push' && github.ref == 'refs/heads/main') ||
           (github.event_name == 'release'))
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          skip-existing: true
================================================
FILE: .github/workflows/gpu_benchmark.yml
================================================
# Runs the GPU benchmark task on a self-hosted runner, compares the result with
# the data cached under ./cache, and on `main` stores/publishes the new data.
name: GPU Benchmarks
on:
push:
branches:
- main
pull_request:
types: [opened, reopened, synchronize]
workflow_dispatch:
schedule:
- cron: "0 0 * * 1" # Run At 00:00 on Monday
permissions:
pull-requests: write
deployments: write
contents: write
jobs:
benchmark:
runs-on: self-hosted
container:
# pixi image pinned by both tag and digest.
image: ghcr.io/prefix-dev/pixi:0.46.0-noble@sha256:c12bcbe8ba5dfd71867495d3471b95a6993b79cc7de7eafec016f8f59e4e4961
options: --rm --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -e "TERM=xterm-256color"
steps:
# git and git-lfs are installed explicitly before the checkout step.
- name: Install Git and Git-LFS
run: |
apt update && apt install -y git git-lfs
- name: Checkout repository
uses: actions/checkout@v6
with:
lfs: true
fetch-depth: 0
# Mark the workspace as a safe git directory, then materialize pixi.lock from LFS.
- name: Fetch pixi.lock from LFS
run: |
git config --global safe.directory /__w/jaxsim/jaxsim
git lfs checkout pixi.lock
# Exposes the SHA of origin/main as the `sha` step output.
# NOTE(review): no later step in this workflow references this output — confirm it is still needed.
- name: Get main branch SHA
id: get-main-branch-sha
run: |
SHA=$(git rev-parse origin/main)
echo "sha=$SHA" >> $GITHUB_OUTPUT
# Restore previously saved benchmark data into ./cache (same key as the save step at the end).
- name: Get benchmark results from main branch
id: cache
uses: actions/cache/restore@v5
with:
path: ./cache
key: ${{ runner.os }}-benchmark
# Writes the benchmark report to output.json, consumed by the steps below.
- name: Run benchmark and store result
run: |
pixi run --frozen --environment gpu benchmark --batch-size 128 --benchmark-json output.json
env:
PY_COLORS: "1"
# Comparison only: `save-data-file: false` and `fail-on-alert: true`.
- name: Compare benchmark results with main branch
uses: benchmark-action/github-action-benchmark@v1.22.0
with:
tool: 'pytest'
output-file-path: output.json
external-data-json-path: ./cache/benchmark-data.json
save-data-file: false
fail-on-alert: true
summary-always: true
comment-always: true
alert-threshold: 150%
github-token: ${{ secrets.GITHUB_TOKEN }}
# On main only: same inputs as the comparison step, but with
# `save-data-file: true` and `fail-on-alert: false`.
- name: Store benchmark result for main branch
uses: benchmark-action/github-action-benchmark@v1.22.0
if: ${{ github.ref_name == 'main' }}
with:
tool: 'pytest'
output-file-path: output.json
external-data-json-path: ./cache/benchmark-data.json
save-data-file: true
fail-on-alert: false
summary-always: true
comment-always: true
alert-threshold: 150%
github-token: ${{ secrets.GITHUB_TOKEN }}
# On main only: uses `benchmark-data-dir-path` instead of the external JSON file;
# `auto-push` is enabled only for scheduled and manually dispatched runs.
# Note that `alert-threshold` is quoted here and unquoted in the two steps above.
- name: Publish Benchmark Results to GitHub Pages
uses: benchmark-action/github-action-benchmark@v1.22.0
if: ${{ github.ref_name == 'main' }}
with:
tool: 'pytest'
output-file-path: output.json
benchmark-data-dir-path: "benchmarks"
fail-on-alert: false
github-token: ${{ secrets.GITHUB_TOKEN }}
comment-on-alert: true
summary-always: true
save-data-file: true
alert-threshold: "150%"
auto-push: ${{ github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' }}
# On main only: persist ./cache under the same key used by the restore step.
- name: Update Benchmark Results cache
uses: actions/cache/save@v5
if: ${{ github.ref_name == 'main' }}
with:
path: ./cache
key: ${{ runner.os }}-benchmark
================================================
FILE: .github/workflows/pixi.yml
================================================
name: Pixi

# Monthly job that refreshes `pixi.lock`, tests the project against the updated
# lockfile and opens a pull request containing the dependency diff.

permissions:
  contents: write
  pull-requests: write

on:
  workflow_dispatch:
  schedule:
    # Execute at 5am UTC on the first day of the month.
    - cron: '0 5 1 * *'

jobs:

  pixi-update:
    runs-on: ubuntu-24.04

    steps:

      - uses: actions/checkout@v6
        with:
          lfs: true

      # `run-install: false`: only the pixi binary is set up here; no
      # environment is installed by this step.
      - name: Set up pixi
        uses: prefix-dev/setup-pixi@v0.9.5
        with:
          run-install: false

      - name: Install pixi-diff-to-markdown
        run: pixi global install pixi-diff-to-markdown

      # `pipefail` makes the step fail if `pixi update` fails, even though it
      # is the first command of the pipe.
      - name: Update pixi lockfile and generate diff
        run: |
          set -o pipefail
          pixi update --json | pixi exec pixi-diff-to-markdown --explicit-column > diff.md

      - name: Test project against updated pixi
        run: pixi run --environment default test
        env:
          PY_COLORS: "1"
          JAX_PLATFORM_NAME: cpu

      # This step only exports a timestamped branch name; the commit and push
      # are performed by the create-pull-request action below.
      - name: Set the branch name
        run: echo "BRANCH_NAME=update-pixi-$(date +'%Y%m%d%H%M%S')" >> "$GITHUB_ENV"

      - name: Create pull request
        uses: peter-evans/create-pull-request@v8
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          commit-message: Update `pixi.lock`
          title: Update `pixi` lockfile
          body-path: diff.md
          branch: ${{ env.BRANCH_NAME }}
          base: main
          labels: pixi
          # Only the lockfile is committed; diff.md is used as the PR body.
          add-paths: pixi.lock
          delete-branch: true
          committer: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
          author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
================================================
FILE: .github/workflows/read_the_docs.yml
================================================
name: Read the Docs PR
on:
pull_request_target:
types:
- opened
permissions:
pull-requests: write
jobs:
documentation-links:
runs-on: ubuntu-latest
steps:
- uses: readthedocs/actions/preview@v1
with:
project-slug: "jaxsim"
project-language: ""
================================================
FILE: .gitignore
================================================
# IDEs
.idea*
.vscode/
# Matlab
*.m~
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
docs/_collections/
docs/modules/_autosummary/
docs/modules/generated
docs/sg_execution_times.rst
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# dynamic version
src/jaxsim/_version.py
# ruff
.ruff_cache/
# pixi environments
.pixi
# data
# Match by extension: a bare `.mp4` / `.png` pattern only ignores files that
# are literally named `.mp4` / `.png`. Files already tracked are unaffected.
*.mp4
*.png
================================================
FILE: .pre-commit-config.yaml
================================================
ci:
autofix_prs: false
autoupdate_schedule: quarterly
submodules: false
default_language_version:
python: python3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-ast
- id: check-merge-conflict
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-toml
- id: check-added-large-files
args: ["--maxkb=2000"]
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 26.3.1
hooks:
- id: black
args: ["--check", "--diff"]
- repo: https://github.com/pycqa/isort
rev: 8.0.1
hooks:
- id: isort
args: ["--check", "--diff"]
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: rst-backticks
- id: rst-directive-colons
- id: rst-inline-touching-normal
- repo: https://github.com/codespell-project/codespell
rev: v2.4.2
hooks:
- id: codespell
args: ["-S", "*.lock"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.9
hooks:
- id: ruff
- repo: https://github.com/kynan/nbstripout
rev: 0.9.1
hooks:
- id: nbstripout
================================================
FILE: .readthedocs.yaml
================================================
version: "2"
build:
os: ubuntu-24.04
tools:
python: "mambaforge-23.11"
conda:
environment: environment.yml
python:
install:
- method: pip
path: .
sphinx:
configuration: docs/conf.py
formats: all
================================================
FILE: CITATION.cff
================================================
cff-version: 1.2.0
title: JaxSim
message: "If you use this software, please cite the paper."
type: software
authors:
- family-names: Ferretti
given-names: Filippo Luca
affiliation: "Generative Bionics"
- family-names: Ferigo
given-names: Diego
affiliation: "Robotics & AI Institute"
- family-names: Croci
given-names: Alessandro
affiliation: "NEURA Robotics"
- family-names: Sartore
given-names: Carlotta
affiliation: "Generative Bionics"
- family-names: Younis
given-names: Omar G.
affiliation: "Quebec AI Institute"
- family-names: Traversaro
given-names: Silvio
affiliation: "Generative Bionics"
- family-names: Pucci
given-names: Daniele
affiliation: "Generative Bionics"
repository-code: "https://github.com/gbionics/jaxsim"
license: BSD-3-Clause
preferred-citation:
type: article
title: "Contact-Aware Morphology Optimization via Physically Consistent Differentiable Simulation"
authors:
- family-names: Ferretti
given-names: Filippo Luca
- family-names: Ferigo
given-names: Diego
- family-names: Croci
given-names: Alessandro
- family-names: Sartore
given-names: Carlotta
- family-names: Younis
given-names: Omar G.
- family-names: Traversaro
given-names: Silvio
- family-names: Pucci
given-names: Daniele
journal: "IEEE Robotics and Automation Letters"
year: 2026
doi: "10.1109/LRA.2026.3678125"
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to JAXsim :rocket:
Hello Contributor,
We're thrilled that you're considering contributing to JAXsim!
Here's a brief guide to help you seamlessly become a part of our project.
## Development Environment :hammer_and_wrench:
Make sure your development environment is set up.
Follow the installation instructions in the [README](./README.md) to get JAXsim and its dependencies up and running.
To ensure consistency and maintain code quality, we recommend using the pre-commit hook with the following configuration.
This will help catch issues before they become a part of the project.
### Setting Up Pre-commit Hook :fishing_pole_and_fish:
`pre-commit` is a tool that manages pre-commit hooks for your project.
It will run checks on your code before you commit it, ensuring that it meets the project's standards.
First, install `pre-commit` if you haven't already:
```bash
pip install pre-commit
```
Then, run the following command to install the hooks:
```bash
pre-commit install
```
### Using Pre-commit Hook :vertical_traffic_light:
Before making any commits, the pre-commit hook will automatically run.
If it finds any issues, it will prevent the commit and provide instructions on how to fix them.
To get your commit through without fixing the issues, use the `--no-verify` flag:
```bash
git commit -m "Your commit message" --no-verify
```
To manually run the pre-commit hook at any time, use:
```bash
pre-commit run --all-files
```
## Making Changes :construction:
Before submitting a pull request, create an issue to discuss your changes if major changes are involved.
This helps us understand your needs and provide feedback.
Clearly describe your pull request, referencing any related issues.
Follow the [PEP 8](https://peps.python.org/pep-0008/) style guide and include relevant tests.
## Testing :test_tube:
Your code will be tested with the CI/CD pipeline before merging.
Feel free to add new tests, or to update the existing ones in the [workflows](./.github/workflows) folder, to cover your changes.
## Documentation :book:
Update the documentation in the [docs](./docs) folder and the [README](./README.md) to reflect your changes, if necessary.
There is no need to build the documentation locally; it will be automatically built and deployed with your pull request, where a preview link will be provided.
## Code Review :eyes:
Expect feedback during the code review process.
Address comments and make necessary changes.
This collaboration ensures quality.
Please keep the commit history clean, or squash commits if necessary.
## License :scroll:
JAXsim is under the [BSD 3-Clause License](./LICENSE).
By contributing, you agree to the same license.
Thank you for contributing to JAXsim! Your efforts are appreciated.
================================================
FILE: LICENSE
================================================
BSD 3-Clause License
Copyright (c) 2022, Artificial and Mechanical Intelligence
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: README.md
================================================
# JaxSim
**JaxSim** is a **differentiable physics engine** built with JAX, tailored for co-design and robotic learning applications.
## Features
- Physically consistent differentiability w.r.t. hardware parameters.
- Closed chain dynamics support.
- Reduced-coordinate physics engine for **fixed-base** and **floating-base** robots.
- Fully Python-based, leveraging [JAX][jax] following a functional programming paradigm.
- Seamless execution on CPUs, GPUs, and TPUs.
- Supports JIT compilation and automatic vectorization for high performance.
- Compatible with SDF models and URDF (via [sdformat][sdformat] conversion).
> [!WARNING]
> This project is still experimental. APIs may change between releases without notice.
> [!NOTE]
> JaxSim currently focuses on locomotion applications.
> Only contacts between bodies and smooth ground surfaces are supported.
## How to use it
```python
import pathlib
import icub_models
import jax.numpy as jnp
import jaxsim.api as js
# Load the iCub model
model_path = icub_models.get_model_file("iCubGazeboV2_5")
joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',
'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',
'r_ankle_roll')
# Build and reduce the model
model_description = pathlib.Path(model_path)
full_model = js.model.JaxSimModel.build_from_model_description(
model_description=model_description,
)
model = js.model.reduce(model=full_model, considered_joints=joints)
# Get the number of degrees of freedom
ndof = model.dofs()
# Initialize data and simulation
# Note that the default data representation is mixed velocity representation
data = js.data.JaxSimModelData.build(
model=model, base_position=jnp.array([0.0, 0.0, 1.0])
)
T = jnp.arange(start=0, stop=1.0, step=model.time_step)
tau = jnp.zeros(ndof)
# Simulate
for _ in T:
data = js.model.step(
model=model, data=data, link_forces=None, joint_force_references=tau
)
```
Check the [examples](./examples) folder for additional use cases!
[jax]: https://github.com/google/jax/
[sdformat]: https://github.com/gazebosim/sdformat
[notation]: https://research.tue.nl/en/publications/multibody-dynamics-notation-version-2
[passive_viewer_mujoco]: https://mujoco.readthedocs.io/en/stable/python.html#passive-viewer
## Installation
With conda
You can install the project using [`conda`][conda] as follows:
```bash
conda install jaxsim -c conda-forge
```
GPU support for JAX will be automatically installed if a compatible GPU is detected.
With pixi
> ### Note
> The minimum version of `pixi` required is `0.39.0`.
Since the `pixi.lock` file is stored using Git LFS, make sure you have [Git LFS](https://github.com/git-lfs/git-lfs/blob/main/INSTALLING.md) installed and properly configured on your system before installation. After cloning the repository, run:
```bash
git lfs install && git lfs pull
```
This ensures all LFS-tracked files are properly downloaded before you proceed with the installation.
You can add the `jaxsim` dependency in your [`pixi`][pixi] project as follows:
```bash
pixi add jaxsim
```
If you are on Linux and you want to use a `cuda`-powered version of `jax`, remember to add the appropriate line in the [`system-requirements`](https://pixi.sh/latest/reference/pixi_manifest/#the-system-requirements-table) table, i.e. adding
~~~toml
[system-requirements]
cuda = "13"
~~~
if you are using a `pixi.toml` file or
~~~toml
[tool.pixi.system-requirements]
cuda = "13"
~~~
if you are using a `pyproject.toml` file.
With pip
You can install the project using [`pypa/pip`][pip], preferably in a [virtual environment][venv], as follows:
```bash
pip install jaxsim
```
Check [`pyproject.toml`](pyproject.toml) for the complete list of optional dependencies.
You can obtain a full installation using `jaxsim[all]`.
If you need URDF support, follow the [official instructions](https://gazebosim.org/docs) to install Gazebo Sim on your operating system,
making sure to obtain `sdformat ≥ 13.0` and `gz-tools ≥ 2.0`.
You don't need to install the entire Gazebo Sim suite.
For example, on Ubuntu, it is sufficient to install the `libsdformat*` and `gz-tools2` packages.
If you need GPU support, follow the official [installation instructions][jax_gpu] of JAX.
Contributors installation (with conda)
If you want to contribute to the project, we recommend creating the following `jaxsim` conda environment first:
```bash
conda env create -f environment.yml
```
Then, activate the environment and install the project in editable mode:
```bash
conda activate jaxsim
pip install --no-deps -e .
```
Contributors installation (with pixi)
> ### Note
> The minimum version of `pixi` required is `0.39.0`.
Since the `pixi.lock` file is stored using Git LFS, make sure you have [Git LFS](https://github.com/git-lfs/git-lfs/blob/main/INSTALLING.md) installed and properly configured on your system before installation. After cloning the repository, run:
```bash
git lfs install && git lfs pull
```
This ensures all LFS-tracked files are properly downloaded before you proceed with the installation.
You can install the default dependencies of the project using [`pixi`][pixi] as follows:
```bash
pixi install
```
See `pixi task list` for a list of available tasks.
[conda]: https://anaconda.org/
[pip]: https://github.com/pypa/pip/
[pixi]: https://pixi.sh/
[venv]: https://docs.python.org/3/tutorial/venv.html
[jax_gpu]: https://github.com/google/jax/#installation
## Documentation
The JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs].
[readthedocs]: https://jaxsim.readthedocs.io/
## Additional features
JaxSim can also be used as a multibody dynamics library, with full support for automatic differentiation of RBDAs (forward and reverse mode) with respect to both kinematic and dynamic parameters.
### Using JaxSim as a multibody dynamics library
```python
import pathlib
import icub_models
import jax.numpy as jnp
import jaxsim.api as js
# Load the iCub model
model_path = icub_models.get_model_file("iCubGazeboV2_5")
joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',
'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',
'r_ankle_roll')
# Build and reduce the model
model_description = pathlib.Path(model_path)
full_model = js.model.JaxSimModel.build_from_model_description(
model_description=model_description,
)
model = js.model.reduce(model=full_model, considered_joints=joints)
# Initialize model data
data = js.data.JaxSimModelData.build(
model=model,
base_position=jnp.array([0.0, 0.0, 1.0]),
)
# Frame and dynamics computations
frame_index = js.frame.name_to_idx(model=model, frame_name="l_foot")
# Frame transformation
W_H_F = js.frame.transform(
model=model, data=data, frame_index=frame_index
)
# Frame Jacobian
W_J_F = js.frame.jacobian(
model=model, data=data, frame_index=frame_index
)
# Dynamics properties
M = js.model.free_floating_mass_matrix(model=model, data=data) # Mass matrix
h = js.model.free_floating_bias_forces(model=model, data=data) # Bias forces
g = js.model.free_floating_gravity_forces(model=model, data=data) # Gravity forces
C = js.model.free_floating_coriolis_matrix(model=model, data=data) # Coriolis matrix
# Print dynamics results
print(f"{M.shape=} \n{h.shape=} \n{g.shape=} \n{C.shape=}")
```
## Credits
The RBDAs are based on the theory of the [Rigid Body Dynamics Algorithms][RBDA]
book by Roy Featherstone.
The algorithms and some simulation features were inspired by its accompanying [code][spatial_v2].
[RBDA]: https://link.springer.com/book/10.1007/978-1-4899-7560-7
[spatial_v2]: http://royfeatherstone.org/spatial/index.html#spatial-software
The development of JaxSim started in late 2021, inspired by early versions of [`google/brax`][brax].
At that time, Brax was implemented in maximal coordinates, and we wanted a physics engine in reduced coordinates.
We are grateful to the Brax team for their work and for showing the potential of [JAX][jax] in this field.
Brax v2 was later implemented with reduced coordinates, following an approach comparable to JaxSim.
The development then shifted to [MJX][mjx], which provides a JAX-based implementation of the Mujoco APIs.
The main differences between MJX/Brax and JaxSim are as follows:
- JaxSim supports out-of-the-box all SDF models with [Pose Frame Semantics][PFS].
- JaxSim only supports collisions between points rigidly attached to bodies and a compliant ground surface.
[brax]: https://github.com/google/brax
[mjx]: https://mujoco.readthedocs.io/en/3.0.0/mjx.html
[PFS]: http://sdformat.org/tutorials?tut=pose_frame_semantics
## Contributing
We welcome contributions from the community.
Please read the [contributing guide](./CONTRIBUTING.md) to get started.
## Citing
If you use JaxSim in your work, please cite the following paper:
```bibtex
@article{ferretti_contact_aware_2026,
author = {Filippo Luca Ferretti and Diego Ferigo and Alessandro Croci and Carlotta Sartore and Omar G. Younis and Silvio Traversaro and Daniele Pucci},
title = {Contact-Aware Morphology Optimization via Physically Consistent Differentiable Simulation},
journal = {IEEE Robotics and Automation Letters},
year = {2026},
doi = {10.1109/LRA.2026.3678125}
}
```
## People
| Authors | Maintainer |
|:------:|:-----------:|
| [][df] [][ff] | [][ff] |
[df]: https://github.com/diegoferigo
[ff]: https://github.com/flferretti
## License
[BSD3](https://choosealicense.com/licenses/bsd-3-clause/)
================================================
FILE: docs/Makefile
================================================
# Minimal makefile for the Sphinx documentation.
#
# SPHINXOPTS and SPHINXBUILD are assigned with `?=`, so a value that is
# already defined (e.g. in the environment) is kept instead of the default.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
# Directory holding conf.py and the documentation sources.
SOURCEDIR = .
# Directory passed to sphinx-build as the output location.
BUILDDIR = _build
# NOTE(review): SPHINXPROJ is not referenced by any recipe in this file.
SPHINXPROJ = JAXsim
# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
# The requested target name ($@) is forwarded as the Sphinx builder name.
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: docs/conf.py
================================================
# Configuration file for the Sphinx documentation builder.
import os
import sys

# When the READTHEDOCS environment variable is set, export CONDA_PREFIX before
# importing jaxsim. `checkout_name` is the name of the directory that contains
# this file, and the prefix is "../../conda/<checkout_name>" resolved against
# the current working directory.
# NOTE(review): presumably this lets dependencies that read CONDA_PREFIX find
# the conda environment used by the hosted build -- confirm.
if os.environ.get("READTHEDOCS"):
    checkout_name = os.path.basename(os.path.dirname(os.path.realpath(__file__)))
    os.environ["CONDA_PREFIX"] = os.path.realpath(
        os.path.join("..", "..", "conda", checkout_name)
    )

import jaxsim

# -- Version information
# The sys.path entries below are added after `import jaxsim`, so they do not
# influence which jaxsim installation was imported above; they only affect
# imports performed later during the build.
sys.path.insert(0, os.path.abspath("."))
sys.path.insert(0, os.path.abspath("../"))
sys.path.insert(0, os.path.abspath("../../"))
module_path = os.path.abspath("../src/")
sys.path.insert(0, module_path)
# The documented version is read from the imported package.
__version__ = jaxsim._version.__version__

# -- Project information
project = "JAXsim"
copyright = "2022, Artificial and Mechanical Intelligence"
author = "Artificial and Mechanical Intelligence"
release = version = __version__

# -- General configuration
# NOTE(review): "sphinx_rtd_theme" is loaded as an extension although
# `html_theme` below selects "sphinx_book_theme" -- confirm it is still needed.
extensions = [
    "sphinx.ext.duration",
    "sphinx.ext.doctest",
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.intersphinx",
    "sphinx.ext.mathjax",
    "sphinx.ext.ifconfig",
    "sphinx.ext.viewcode",
    "sphinx_rtd_theme",
    "sphinx.ext.napoleon",
    "sphinx_autodoc_typehints",
    "sphinx_multiversion",
    "myst_nb",
    "sphinx_gallery.gen_gallery",
    "sphinxcontrib.collections",
    "sphinx_design",
]

# -- Options for HTML output, autodoc and autosummary
# (no `intersphinx_mapping` is defined in this file)
language = "en"
html_theme = "sphinx_book_theme"
templates_path = ["_templates"]
html_title = f"JAXsim {version}"
master_doc = "index"
autodoc_typehints_format = "short"
autodoc_typehints = "description"
autosummary_generate = True
epub_show_urls = "footnote"

# Map the jaxsim.typing aliases to themselves so that autodoc shows the alias
# name instead of the expanded type. According to the Sphinx documentation,
# this only takes effect in modules that enable postponed evaluation of
# annotations (PEP 563).
autodoc_type_aliases = {
    "jaxsim.typing.PyTree": "jaxsim.typing.PyTree",
    "jaxsim.typing.Vector": "jaxsim.typing.Vector",
    "jaxsim.typing.Matrix": "jaxsim.typing.Matrix",
    "jaxsim.typing.Array": "jaxsim.typing.Array",
    "jaxsim.typing.Int": "jaxsim.typing.Int",
    "jaxsim.typing.Bool": "jaxsim.typing.Bool",
    "jaxsim.typing.Float": "jaxsim.typing.Float",
    "jaxsim.typing.ScalarLike": "jaxsim.typing.ScalarLike",
    "jaxsim.typing.ArrayLike": "jaxsim.typing.ArrayLike",
    "jaxsim.typing.VectorLike": "jaxsim.typing.VectorLike",
    "jaxsim.typing.MatrixLike": "jaxsim.typing.MatrixLike",
    "jaxsim.typing.IntLike": "jaxsim.typing.IntLike",
    "jaxsim.typing.BoolLike": "jaxsim.typing.BoolLike",
    "jaxsim.typing.FloatLike": "jaxsim.typing.FloatLike",
}

# -- Options for sphinx-collections
# Copy the repository's examples folder into the docs collections, ignoring
# its "assets" entry.
collections = {
    "examples": {"driver": "copy_folder", "source": "../examples/", "ignore": "assets"}
}

# -- Options for sphinx-gallery ----------------------------------------------
sphinx_gallery_conf = {
    "examples_dirs": "../examples",
    "gallery_dirs": "../generated_examples/",
    "doc_module": "jaxsim",
}

# -- Options for myst -------------------------------------------------------
myst_enable_extensions = [
    "amsmath",
    "dollarmath",
]
# Notebook execution settings for myst-nb. With `nb_execution_raise_on_error`
# enabled, a failing notebook aborts the build instead of being reported only
# as a warning.
nb_execution_mode = "auto"
nb_execution_raise_on_error = True
nb_render_image_options = {
    "scale": "60",
}
# Execution timeout; per the myst-nb documentation the value is in seconds.
nb_execution_timeout = 180
source_suffix = [".rst", ".md", ".ipynb"]

# Ignore header warnings
suppress_warnings = ["myst.header"]
================================================
FILE: docs/examples.rst
================================================
.. _collections:
Example Notebooks
=================
.. toctree::
:glob:
:hidden:
:maxdepth: 1
_collections/examples/README.md
.. raw:: html
.. only:: html
:doc:`_collections/examples/jaxsim_as_physics_engine`
.. raw:: html
JaxSim as a hardware-accelerated parallel physics engine
.. only:: html
:doc:`_collections/examples/jaxsim_as_physics_engine_advanced`
.. raw:: html
JaxSim as a hardware-accelerated parallel physics engine [Advanced]
.. only:: html
:doc:`_collections/examples/jaxsim_as_multibody_dynamics_library`
.. raw:: html
JaxSim as a multibody dynamics library
.. only:: html
:doc:`_collections/examples/jaxsim_for_robot_controllers`
.. raw:: html
JaxSim for developing closed-loop robot controllers
================================================
FILE: docs/guide/configuration.rst
================================================
Configuration
=============
JaxSim utilizes environment variables for application configuration. Below is a detailed overview of the various configuration categories and their respective variables.
Collision Dynamics
~~~~~~~~~~~~~~~~~~
Environment variables starting with ``JAXSIM_COLLISION_`` are used to configure collision dynamics. The available variables are:
- ``JAXSIM_COLLISION_SPHERE_POINTS``: Specifies the number of collision points to approximate the sphere.
*Default:* ``50``.
- ``JAXSIM_COLLISION_MESH_ENABLED``: Enables or disables mesh-based collision detection.
*Default:* ``False``.
- ``JAXSIM_COLLISION_USE_BOTTOM_ONLY``: Limits collision detection to only the bottom half of the box or sphere.
*Default:* ``False``.
.. note::
The bottom half is defined as the half of the box or sphere with the lowest z-coordinate in the collision link frame.
Testing
~~~~~~~
For testing configurations, environment variables beginning with ``JAXSIM_TEST_`` are used. The following variables are available:
- ``JAXSIM_TEST_SEED``: Defines the seed for the random number generator.
*Default:* ``0``.
- ``JAXSIM_TEST_AD_ORDER``: Specifies the gradient order for automatic differentiation tests.
*Default:* ``1``.
- ``JAXSIM_TEST_FD_STEP_SIZE``: Sets the step size for finite difference tests.
*Default:* the cube root of the machine epsilon.
Joint Dynamics
~~~~~~~~~~~~~~
Joint dynamics are configured using environment variables starting with ``JAXSIM_JOINT_``. Available variables include:
- ``JAXSIM_JOINT_POSITION_LIMIT_DAMPER``: Overrides the damper value for joint position limits of the SDF model.
- ``JAXSIM_JOINT_POSITION_LIMIT_SPRING``: Overrides the spring value for joint position limits of the SDF model.
Logging and Exceptions
~~~~~~~~~~~~~~~~~~~~~~
The logging and exception configuration is controlled by the following environment variables:
- ``JAXSIM_LOGGING_LEVEL``: Determines the logging level.
*Default:* ``DEBUG`` for development, ``WARNING`` for production.
- ``JAXSIM_ENABLE_EXCEPTIONS``: Enables the runtime checks and exceptions. Note that enabling exceptions might lead to device-to-host transfer of data, increasing the computational time required.
*Default:* ``False``.
.. note::
Runtime exceptions are disabled by default on TPU.
================================================
FILE: docs/guide/install.rst
================================================
Installation
============
.. _installation:
Prerequisites
-------------
JAXsim requires Python 3.11 or later.
Basic Installation
------------------
You can install the project using `conda`_:
.. code-block:: bash
conda install jaxsim -c conda-forge
Alternatively, you can use `pypa/pip`_, preferably in a `virtual environment`_:
.. code-block:: bash
pip install jaxsim
Have a look at `pyproject.toml`_ for a complete list of optional dependencies.
You can install all of them with ``pip install "jaxsim[all]"``.
.. note::
If you need GPU support, please follow the official `installation instruction`_ of JAX.
.. _conda: https://anaconda.org/
.. _pyproject.toml: https://github.com/gbionics/jaxsim/blob/main/pyproject.toml
.. _pypa/pip: https://github.com/pypa/pip/
.. _virtual environment: https://docs.python.org/3.8/tutorial/venv.html
.. _installation instruction: https://github.com/google/jax/#installation
================================================
FILE: docs/index.rst
================================================
JAXsim
#######
A scalable physics engine and multibody dynamics library implemented with JAX. With JIT batteries 🔋
.. note::
This simulator currently focuses on locomotion applications. Only contacts with ground are supported.
Features
--------
.. grid::
.. grid-item::
:columns: 12 12 12 6
.. card:: Performance
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5
.. div:: sd-font-normal
Physics engine in reduced coordinates implemented with JAX_.
Compatibility with JIT compilation for increased performance and transparent support to execute logic on CPUs, GPUs, and TPUs.
Parallel multi-body simulations on hardware accelerators for significantly increased throughput
.. grid-item::
:columns: 12 12 12 6
.. card:: Model Parsing
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5
.. div:: sd-font-normal
Support for SDF models (and, upon conversion, URDF models). Revolute, prismatic, and fixed joints supported.
.. grid-item::
:columns: 12 12 12 6
.. card:: Automatic Differentiation
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5
.. div:: sd-font-normal
Support for automatic differentiation of rigid body dynamics algorithms (RBDAs) for model-based robotics research.
Soft contacts model supporting full friction cone and sticking / slipping transition.
.. grid-item::
:columns: 12 12 12 6
.. card:: Complex Dynamics
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5
.. div:: sd-font-normal
JAXsim provides a variety of integrators for the simulation of multibody dynamics, including RK4, Heun, Euler, and more.
Support of `multiple velocities representations `_.
----
.. toctree::
:hidden:
guide/install
guide/configuration
examples
.. toctree::
:hidden:
:maxdepth: 2
:caption: JAXsim API
modules/api
modules/math
modules/mujoco
modules/parsers
modules/rbda
modules/typing
modules/utils
Examples
--------
Explore and learn how to use the library through practical demonstrations available in the `examples `__ folder.
Credits
-------
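The physics module of JAXsim is based on the theory of the `Rigid Body Dynamics Algorithms <https://link.springer.com/book/10.1007/978-1-4899-7560-7>`_ book by Roy Featherstone.
We structured part of our logic following its accompanying `code <http://royfeatherstone.org/spatial/index.html#spatial-software>`_.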
The physics engine is developed entirely in Python using JAX_.
The inspiration for developing JAXsim originally stemmed from early versions of Brax_.
Here below we summarize the differences between the projects:
- JAXsim simulates multibody dynamics in reduced coordinates, while :code:`brax v1` uses maximal coordinates.
- The new v2 APIs of brax (and the new MJX_) were then implemented in reduced coordinates, following an approach comparable to JAXsim, with major differences in contact handling.
- The rigid-body algorithms used in JAXsim make it possible to efficiently compute quantities based on the Euler-Poincaré formulation of the equations of motion, necessary for model-based robotics research.
- JAXsim supports SDF (and, indirectly, URDF) models, assuming the model is described with the recent `Pose Frame Semantics <http://sdformat.org/tutorials?tut=pose_frame_semantics>`_.
- Contrary to Brax, JAXsim only supports collision detection between bodies and a compliant ground surface.
- The RBDAs of JAXsim support automatic differentiation, but this functionality has not been thoroughly tested.
People
------
Authors
'''''''
`Diego Ferigo `_
`Filippo Luca Ferretti `_
Maintainers
'''''''''''
`Filippo Luca Ferretti `_
`Alessandro Croci `_
License
-------
`BSD3 `_
.. _Brax: https://github.com/google/brax
.. _MJX: https://mujoco.readthedocs.io/en/3.0.0/mjx.html
.. _JAX: https://github.com/google/jax
================================================
FILE: docs/make.bat
================================================
@ECHO OFF
REM Run from the directory that contains this script (%~dp0); undone by popd at :end.
pushd %~dp0
REM Command file for the Sphinx documentation
REM Use "sphinx-build" unless SPHINXBUILD is already set to a non-empty value.
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
REM NOTE(review): SPHINXPROJ is not referenced by any command in this file.
set SPHINXPROJ=JaxSim
REM Without a first argument, show the Sphinx help.
if "%1" == "" goto help
REM Probe the Sphinx executable with all output discarded; only the resulting
REM errorlevel is inspected below.
%SPHINXBUILD% >NUL 2>NUL
REM "if errorlevel 9009" is true when the errorlevel is 9009 or greater.
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
REM Forward the first argument to Sphinx "make mode" as the builder name.
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
================================================
FILE: docs/modules/api.rst
================================================
Functional API
==============
.. currentmodule:: jaxsim.api
.. autosummary::
:toctree: _autosummary
model
data
contact
kin_dyn_parameters
integrators
joint
link
frame
com
ode
references
actuation_model
common
Model
~~~~~
.. automodule:: jaxsim.api.model
:members:
:no-index:
.. automodule:: jaxsim.api.actuation_model
:members:
:no-index:
Data
~~~~
.. automodule:: jaxsim.api.data
:members:
:no-index:
Contact
~~~~~~~
.. automodule:: jaxsim.api.contact
:members:
:no-index:
KinDynParameters
~~~~~~~~~~~~~~~~
.. automodule:: jaxsim.api.kin_dyn_parameters
:members:
:no-index:
Joint
~~~~~
.. automodule:: jaxsim.api.joint
:members:
:no-index:
Link
~~~~~
.. automodule:: jaxsim.api.link
:members:
:no-index:
Frame
~~~~~
.. automodule:: jaxsim.api.frame
:members:
:no-index:
CoM
~~~
.. automodule:: jaxsim.api.com
:members:
:no-index:
Integration
~~~~~~~~~~~
.. automodule:: jaxsim.api.integrators
:members:
:no-index:
.. automodule:: jaxsim.api.ode
:members:
:no-index:
References
~~~~~~~~~~
.. automodule:: jaxsim.api.references
:members:
:no-index:
Common
~~~~~~
.. autoclass:: jaxsim.api.common.VelRepr
:members:
.. autoclass:: jaxsim.api.common.ModelDataWithVelocityRepresentation
:members:
================================================
FILE: docs/modules/math.rst
================================================
Math
====
.. currentmodule:: jaxsim.math
.. automodule:: jaxsim.math.adjoint
:members:
:undoc-members:
.. automodule:: jaxsim.math.cross
:members:
:undoc-members:
.. automodule:: jaxsim.math.inertia
:members:
:undoc-members:
.. automodule:: jaxsim.math.quaternion
:members:
:undoc-members:
.. automodule:: jaxsim.math.rotation
:members:
:undoc-members:
.. automodule:: jaxsim.math.skew
:members:
:undoc-members:
================================================
FILE: docs/modules/mujoco.rst
================================================
MuJoCo Visualizer
==================
JAXsim provides a simple interface with MuJoCo's visualizer. The visualizer is
a separate process that communicates with the main simulation process. This
allows for the simulation to run at full speed while the visualizer can run at
a different frame rate.
.. currentmodule:: jaxsim.mujoco
Loaders
~~~~~~~
.. automodule:: jaxsim.mujoco.loaders
:members:
Model
~~~~~
.. automodule:: jaxsim.mujoco.model
:members:
Visualizer
~~~~~~~~~~
.. automodule:: jaxsim.mujoco.visualizer
:members:
================================================
FILE: docs/modules/parsers.rst
================================================
Parsers
=======
.. automodule:: jaxsim.parsers.descriptions.collision
:members:
.. automodule:: jaxsim.parsers.descriptions.joint
:members:
.. automodule:: jaxsim.parsers.descriptions.link
:members:
.. automodule:: jaxsim.parsers.descriptions.model
:members:
================================================
FILE: docs/modules/rbda.rst
================================================
Rigid Body Dynamics Algorithms
==============================
This module provides a set of algorithms for rigid body dynamics.
.. currentmodule:: jaxsim.rbda
.. autosummary::
:toctree: _autosummary
aba
collidable_points
contacts.soft
contacts.rigid
contacts.relaxed_rigid
crba
forward_kinematics
jacobian
utils
Collision Detection
~~~~~~~~~~~~~~~~~~~
.. automodule:: jaxsim.rbda.collidable_points
:members:
:no-index:
Contact Models
~~~~~~~~~~~~~~
.. automodule:: jaxsim.rbda.contacts.soft
:members:
:no-index:
.. automodule:: jaxsim.rbda.contacts.rigid
:members:
:no-index:
.. automodule:: jaxsim.rbda.contacts.relaxed_rigid
:members:
:no-index:
Utilities
~~~~~~~~~
.. automodule:: jaxsim.rbda.utils
:members:
:no-index:
================================================
FILE: docs/modules/typing.rst
================================================
Typing
======
.. currentmodule:: jaxsim.typing
.. autosummary::
PyTree
Matrix
Bool
Int
Float
Vector
BoolLike
FloatLike
IntLike
ArrayLike
VectorLike
MatrixLike
================================================
FILE: docs/modules/utils.rst
================================================
Utils
=====
.. automodule:: jaxsim.utils
:members:
:inherited-members:
.. autoclass:: jaxsim.utils.JaxsimDataclass
:members:
:inherited-members:
================================================
FILE: environment.yml
================================================
name: jaxsim
channels:
- conda-forge
dependencies:
# ===========================
# Dependencies from pyproject.toml
# ===========================
- python >= 3.12.0
- coloredlogs
- jax >= 0.4.34
- jaxlib >= 0.4.34
- jaxlie >= 1.3.0
- jax-dataclasses >= 1.4.0
- optax >= 0.2.3
- pptree
- qpax
- rod >= 0.3.3
- trimesh
- typing_extensions # python<3.12
# ====================================
# Optional dependencies from pyproject.toml
# ====================================
# [testing]
- chex
- idyntree >= 12.2.1
- pytest
- pytest-benchmark
- pytest-icdiff
- robot_descriptions >= 1.16.0
- icub-models
# [viz]
- lxml
- mediapy
- mujoco >= 3.0.0
- scipy >= 1.14.0
# ==========================
# Documentation dependencies
# ==========================
- cachecontrol
- filecache
- jinja2
- myst-nb
- pip
- sphinx
- sphinx-autodoc-typehints
- sphinx-book-theme
- sphinx-copybutton
- sphinx-design
- sphinx_fontawesome
- sphinx-gallery
- sphinx-jinja2-compat
- sphinx-multiversion
- sphinx_rtd_theme
- sphinx-toolbox
- icub-models
# ========================================
# Other dependencies for GitHub Codespaces
# ========================================
- ipython
- pip:
- sphinx-collections # TODO (flferretti): PR to conda-forge
================================================
FILE: examples/.gitattributes
================================================
# GitHub syntax highlighting
pixi.lock linguist-language=YAML
================================================
FILE: examples/.gitignore
================================================
# pixi environments
.pixi
================================================
FILE: examples/README.md
================================================
# JaxSim Examples
This folder contains Jupyter notebooks that demonstrate the practical usage of JaxSim.
## Featured examples
| Notebook | Google Colab | Description |
| :--- | :---: | :--- |
| [`jaxsim_as_multibody_dynamics_library`](./jaxsim_as_multibody_dynamics_library.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_multibody_dynamics] | An example demonstrating how to use JaxSim as a multibody dynamics library. |
| [`jaxsim_as_physics_engine.ipynb`](./jaxsim_as_physics_engine.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_physics_engine] | An example demonstrating how to simulate vectorized models in parallel. |
| [`jaxsim_as_physics_engine_advanced.ipynb`](./jaxsim_as_physics_engine_advanced.ipynb) | [![Open In Colab][colab_badge]][jaxsim_as_physics_engine_advanced] | An example showcasing advanced JaxSim usage, such as customizing the integrator, contact model, and more. |
| [`jaxsim_for_robot_controllers.ipynb`](./jaxsim_for_robot_controllers.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_closed_loop] | A basic example showing how to simulate a PD controller with gravity compensation for a 2-DOF cart-pole. |
[colab_badge]: https://colab.research.google.com/assets/colab-badge.svg
[ipynb_jaxsim_closed_loop]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_for_robot_controllers.ipynb
[ipynb_jaxsim_as_physics_engine]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_physics_engine.ipynb
[jaxsim_as_physics_engine_advanced]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_physics_engine_advanced.ipynb
[ipynb_jaxsim_as_multibody_dynamics]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_multibody_dynamics_library.ipynb
## How to run the examples
You can run the JaxSim examples with hardware acceleration in two ways.
### Option 1: Google Colab (recommended)
The easiest way is to use the provided Google Colab links to run the notebooks in a hosted environment
with no setup required.
### Option 2: Local execution with `pixi`
To run the examples locally, first install `pixi` following the [official documentation][pixi_installation]:
[pixi_installation]: https://pixi.sh/#installation
```bash
curl -fsSL https://pixi.sh/install.sh | bash
```
Then, from the repository's root directory, execute the example notebooks using:
```bash
pixi run examples
```
This command will automatically handle all necessary dependencies and run the examples in a self-contained environment.
================================================
FILE: examples/assets/build_cartpole_urdf.py
================================================
import os
if "ROD_LOGGING_LEVEL" not in os.environ:
os.environ["ROD_LOGGING_LEVEL"] = "WARNING"
import numpy as np
import rod.kinematics.tree_transforms
from rod.builder import primitives
# Script entry point: build an in-memory cart-pole model (rail, cart, pole)
# with ROD, export it to URDF and print the URDF string to standard output.
# The top-level name `rod` used below is bound by the module-level
# `import rod.kinematics.tree_transforms` statement.
if __name__ == "__main__":

    # ================
    # Model parameters
    # ================

    # NOTE(review): units are not stated in this file; presumably SI -- confirm.

    # Rail parameters.
    rail_height = 1.2
    rail_length = 5.0
    rail_radius = 0.005
    rail_mass = 5.0

    # Cart parameters.
    cart_mass = 1.0
    cart_size = (0.1, 0.2, 0.05)  # passed below as the box x, y, z arguments

    # Pole parameters.
    pole_mass = 0.5
    pole_length = 1.0
    pole_radius = 0.005

    # ========================
    # Create the link builders
    # ========================

    rail_builder = primitives.CylinderBuilder(
        name="rail",
        mass=rail_mass,
        radius=rail_radius,
        length=rail_length,
    )

    cart_builder = primitives.BoxBuilder(
        name="cart",
        mass=cart_mass,
        x=cart_size[0],
        y=cart_size[1],
        z=cart_size[2],
    )

    pole_builder = primitives.CylinderBuilder(
        name="pole",
        mass=pole_mass,
        radius=pole_radius,
        length=pole_length,
    )

    # =================
    # Create the joints
    # =================

    # Fixed joint attaching the rail to the world.
    world_to_rail = rod.Joint(
        name="world_to_rail",
        type="fixed",
        parent="world",
        child=rail_builder.name,
        pose=primitives.PrimitiveBuilder.build_pose(
            relative_to="world",
        ),
    )

    # Prismatic joint moving the cart along the y axis, placed at the rail
    # height. The position limits are half the rail length minus half the
    # cart's y size, symmetric around zero.
    linear = rod.Joint(
        name="linear",
        type="prismatic",
        parent=rail_builder.name,
        child=cart_builder.name,
        pose=primitives.PrimitiveBuilder.build_pose(
            relative_to=rail_builder.name,
            pos=np.array([0, 0, rail_height]),
        ),
        axis=rod.Axis(
            xyz=rod.Xyz(xyz=[0, 1, 0]),
            limit=rod.Limit(
                upper=(rail_length / 2 - cart_size[1] / 2),
                lower=-(rail_length / 2 - cart_size[1] / 2),
                effort=500.0,
                velocity=10.0,
            ),
        ),
    )

    # Continuous joint about the x axis connecting the pole to the cart,
    # created with a default-constructed limit.
    pivot = rod.Joint(
        name="pivot",
        type="continuous",
        parent=cart_builder.name,
        child=pole_builder.name,
        pose=primitives.PrimitiveBuilder.build_pose(
            relative_to=cart_builder.name,
        ),
        axis=rod.Axis(
            xyz=rod.Xyz(xyz=[1, 0, 0]),
            limit=rod.Limit(),
        ),
    )

    # ================
    # Create the links
    # ================

    # Pose shared by the rail's inertial, visual and collision elements:
    # raised to the rail height and rolled by pi/2.
    # NOTE(review): presumably the roll aligns the cylinder with the y axis
    # used by the prismatic joint -- confirm the primitive's default axis.
    rail_elements_pose = primitives.PrimitiveBuilder.build_pose(
        pos=np.array([0, 0, rail_height]),
        rpy=np.array([np.pi / 2, 0, 0]),
    )

    rail = (
        rail_builder.build_link(
            name=rail_builder.name,
            pose=primitives.PrimitiveBuilder.build_pose(
                relative_to=world_to_rail.name,
            ),
        )
        .add_inertial(pose=rail_elements_pose)
        .add_visual(pose=rail_elements_pose)
        .add_collision(pose=rail_elements_pose)
        .build()
    )

    # The cart elements are added without an explicit pose.
    cart = (
        cart_builder.build_link(
            name=cart_builder.name,
            pose=primitives.PrimitiveBuilder.build_pose(relative_to=linear.name),
        )
        .add_inertial()
        .add_visual()
        .add_collision()
        .build()
    )

    # Pose shared by the pole's elements: offset by half the pole length
    # along z.
    pole_elements_pose = primitives.PrimitiveBuilder.build_pose(
        pos=np.array([0, 0, pole_length / 2]),
    )

    pole = (
        pole_builder.build_link(
            name=pole_builder.name,
            pose=primitives.PrimitiveBuilder.build_pose(
                relative_to=pivot.name,
            ),
        )
        .add_inertial(pose=pole_elements_pose)
        .add_visual(pose=pole_elements_pose)
        .add_collision(pose=pole_elements_pose)
        .build()
    )

    # ===========
    # Build model
    # ===========

    # Create ROD in-memory model.
    model = rod.Model(
        name="cartpole",
        canonical_link=rail.name,
        link=[
            rail,
            cart,
            pole,
        ],
        joint=[
            world_to_rail,
            linear,
            pivot,
        ],
    )

    # Update the pose elements to be closer to those expected in URDF.
    model.switch_frame_convention(
        frame_convention=rod.FrameConvention.Urdf, explicit_frames=True
    )

    # ==============
    # Get SDF string
    # ==============

    # Create the top-level SDF object.
    sdf = rod.Sdf(version="1.10", model=model)

    # Generate the SDF string.
    # (Left disabled: only the URDF string is produced and printed below.)
    # sdf_string = sdf.serialize(pretty=True, validate=True)

    # ===============
    # Get URDF string
    # ===============

    # Imported here, right before its only use in this script.
    import rod.urdf.exporter

    # Convert the SDF to URDF.
    urdf_string = rod.urdf.exporter.UrdfExporter(
        pretty=True, indent=" "
    ).to_urdf_string(sdf=sdf)

    # Print the URDF string.
    print(urdf_string)
================================================
FILE: examples/assets/cartpole.urdf
================================================
================================================
FILE: examples/jaxsim_as_multibody_dynamics_library.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "DpLq0-lltwZ1"
},
"source": [
"# `JaxSim` as a multibody dynamics library\n",
"\n",
"JaxSim was initially developed as a **hardware-accelerated physics engine**. Over time, it has evolved, adding new features to become a comprehensive **JAX-based multibody dynamics library**.\n",
"\n",
"In this notebook, you'll explore the main APIs for loading robot models and computing key quantities for applications such as control, planning, and more.\n",
"\n",
"A key advantage of JaxSim is its ability to create fully differentiable closed-loop systems, enabling end-to-end optimization. Combined with the flexibility to parameterize model kinematics and dynamics, JaxSim can serve as an excellent playground for robot learning applications.\n",
"\n",
"\n",
" \n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rcEwprINtwZ3"
},
"source": [
"## Prepare environment\n",
"\n",
"First, we need to install the necessary packages and import their resources."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "u4xL7dbBtwZ3",
"outputId": "1a088e28-e005-4910-928c-cb641e589ab5"
},
"outputs": [],
"source": [
"# @title Imports and setup\n",
"from IPython.display import clear_output\n",
"import sys\n",
"\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"\n",
"# Install JAX, sdformat, and other notebook dependencies.\n",
"if IS_COLAB:\n",
" !{sys.executable} -m pip install --pre -qU jaxsim\n",
" # Quote the requirement: an unquoted '>' would be parsed by the shell as an output redirection.\n",
" !{sys.executable} -m pip install \"robot_descriptions>=1.16.0\"\n",
" !apt install -qq lsb-release wget gnupg\n",
" !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n",
" !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n",
" !apt -qq update\n",
" !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n",
"\n",
" clear_output()\n",
"\n",
"import os\n",
"import pathlib\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jaxsim.api as js\n",
"import jaxsim.math\n",
"from jaxsim import logging\n",
"from jaxsim import VelRepr\n",
"\n",
"logging.set_logging_level(logging.LoggingLevel.WARNING)\n",
"print(f\"Running on {jax.devices()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fN8Xg4QgtwZ4"
},
"source": [
"## Robot model\n",
"\n",
"JaxSim allows loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files.\n",
"\n",
"In this example, we will use the [ErgoCub][ergocub] humanoid robot model. If you have a URDF/SDF file for your robot that is compatible with [`gazebosim/sdformat`][sdformat_github][1], it should work out-of-the-box with JaxSim.\n",
"\n",
"[sdformat]: http://sdformat.org/\n",
"[urdf]: http://wiki.ros.org/urdf/\n",
"[ergocub]: https://ergocub.eu/\n",
"[sdformat_github]: https://github.com/gazebosim/sdformat\n",
"\n",
"---\n",
"\n",
"[1]: JaxSim validates robot descriptions using the command `gz sdf -p /path/to/file.urdf`. Ensure this command runs successfully on your file.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rB0BFxyPtwZ5"
},
"outputs": [],
"source": [
"# @title Fetch the URDF file\n",
"\n",
"try:\n",
" os.environ[\"ROBOT_DESCRIPTION_COMMIT\"] = \"v0.7.7\"\n",
"\n",
" import robot_descriptions.ergocub_description\n",
"\n",
"finally:\n",
" _ = os.environ.pop(\"ROBOT_DESCRIPTION_COMMIT\", None)\n",
"\n",
"model_description_path = pathlib.Path(\n",
" robot_descriptions.ergocub_description.URDF_PATH.replace(\n",
" \"ergoCubSN002\", \"ergoCubSN001\"\n",
" )\n",
")\n",
"\n",
"clear_output()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jeTUZic8twZ5"
},
"source": [
"### Create the model and its data\n",
"\n",
"The dynamics of a generic floating-base model are governed by the following equations of motion:\n",
"\n",
"$$\n",
"M(\\mathbf{q}) \\dot{\\boldsymbol{\\nu}} + \\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) = B \\boldsymbol{\\tau} + \\sum_{L_i \\in \\mathcal{L}} J_{W,L_i}^\\top(\\mathbf{q}) \\: \\mathbf{f}_i\n",
".\n",
"$$\n",
"\n",
"Here, the system state is represented by:\n",
"\n",
"- $\\mathbf{q} = ({}^W \\mathbf{H}_B, \\, \\mathbf{s}) \\in \\text{SE}(3) \\times \\mathbb{R}^n$ is the generalized position.\n",
"- $\\boldsymbol{\\nu} = (\\boldsymbol{v}_{W,B}, \\, \\boldsymbol{\\omega}_{W,B}, \\, \\dot{\\mathbf{s}}) \\in \\mathbb{R}^{6+n}$ is the generalized velocity.\n",
"\n",
"The inputs to the system are:\n",
"\n",
"- $\\boldsymbol{\\tau} \\in \\mathbb{R}^n$ are the joint torques.\n",
"- $\\mathbf{f}_i \\in \\mathbb{R}^6$ is the 6D force applied to the link $L_i$.\n",
"\n",
"JaxSim exposes functional APIs to operate over the following two main data structures:\n",
"\n",
"- **`JaxSimModel`** stores all the constant information parsed from the model description.\n",
"- **`JaxSimModelData`** holds the state of the model.\n",
"\n",
"Additionally, JaxSim includes a utility class, **`JaxSimModelReferences`**, for managing and manipulating system inputs.\n",
"\n",
"---\n",
"\n",
"This notebook uses the notation summarized in the following report. Please refer to this document if you have any questions or if something is unclear.\n",
"\n",
"> Traversaro and Saccon, **Multibody dynamics notation**, 2019, [URL](https://pure.tue.nl/ws/portalfiles/portal/139293126/A_Multibody_Dynamics_Notation_Revision_2_.pdf)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WYgBAxU0twZ6"
},
"outputs": [],
"source": [
"# Create the model from the model description.\n",
"# JaxSim removes all fixed joints by lumping together their parent and child links.\n",
"full_model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_description_path\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DdaETmDStwZ6"
},
"source": [
"It is often useful to work with only a subset of joints, referred to as the _considered joints_. JaxSim allows reducing a model so that the computation of the rigid body dynamics quantities is simplified.\n",
"\n",
"By default, the positions of the removed joints are considered to be zero. If this is not the case, the `reduce` function accepts a dictionary `dict[str, float]` to specify custom joint positions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QuhG7Zv5twZ7"
},
"outputs": [],
"source": [
"model = js.model.reduce(\n",
" model=full_model,\n",
" considered_joints=tuple(\n",
" j\n",
" for j in full_model.joint_names()\n",
" # Remove sensor joints.\n",
" if \"camera\" not in j\n",
" # Remove head and hands.\n",
" and \"neck\" not in j\n",
" and \"wrist\" not in j\n",
" and \"thumb\" not in j\n",
" and \"index\" not in j\n",
" and \"middle\" not in j\n",
" and \"ring\" not in j\n",
" and \"pinkie\" not in j\n",
" # Remove upper body.\n",
" and \"torso\" not in j and \"elbow\" not in j and \"shoulder\" not in j\n",
" ),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RLvAit_i2ZiA",
"outputId": "ea3954af-b9b9-46ac-d9cb-20b99b1eac94"
},
"outputs": [],
"source": [
"# Print model quantities.\n",
"print(f\"Model name: {model.name()}\")\n",
"print(f\"Number of links: {model.number_of_links()}\")\n",
"print(f\"Number of joints: {model.number_of_joints()}\")\n",
"\n",
"print()\n",
"print(f\"Links:\\n{model.link_names()}\")\n",
"\n",
"print()\n",
"print(f\"Joints:\\n{model.joint_names()}\")\n",
"\n",
"print()\n",
"print(f\"Frames:\\n{model.frame_names()}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Xp8V5on5twZ8",
"outputId": "cc1564db-ae91-4dba-92c9-b8b87bd65f10"
},
"outputs": [],
"source": [
"# Create a random data object from the reduced model.\n",
"data = js.data.random_model_data(model=model)\n",
"\n",
"# Print the state.\n",
"W_H_B, s = data.generalized_position\n",
"ν = data.generalized_velocity\n",
"\n",
"print(f\"W_H_B: shape={W_H_B.shape}\\n{W_H_B}\\n\")\n",
"print(f\"s: shape={s.shape}\\n{s}\\n\")\n",
"print(f\"ν: shape={ν.shape}\\n{ν}\\n\") # noqa: RUF001"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XLx3sv9VtwZ9",
"outputId": "28f5f070-e37e-464e-d84e-2944cfdc28dc"
},
"outputs": [],
"source": [
"# Create a random link forces matrix.\n",
"link_forces = jax.random.uniform(\n",
" minval=-10.0,\n",
" maxval=10.0,\n",
" shape=(model.number_of_links(), 6),\n",
" key=jax.random.PRNGKey(0),\n",
")\n",
"\n",
"# Create a random joint force references vector.\n",
"# Note that these are called 'references' because the actual joint forces that\n",
"# are actuated might differ due to effects like joint friction.\n",
"joint_force_references = jax.random.uniform(\n",
" minval=-10.0, maxval=10.0, shape=(model.dofs(),), key=jax.random.PRNGKey(0)\n",
")\n",
"\n",
"# Create the references object.\n",
"references = js.references.JaxSimModelReferences.build(\n",
" model=model,\n",
" data=data,\n",
" link_forces=link_forces,\n",
" joint_force_references=joint_force_references,\n",
")\n",
"\n",
"print(f\"link_forces: shape={references.link_forces(model=model, data=data).shape}\")\n",
"print(f\"joint_force_references: shape={references.joint_force_references(model=model).shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AaG817vP4LfT"
},
"source": [
"## Robot Kinematics\n",
"\n",
"JaxSim offers functional APIs for computing kinematic quantities:\n",
"\n",
"- **`jaxsim.api.model`**: vectorized functions operating on the whole model.\n",
"- **`jaxsim.api.link`**: functions operating on individual links.\n",
"- **`jaxsim.api.frame`**: functions operating on individual frames. \n",
"\n",
"Due to JAX limitations on vectorizable data types, many APIs operate on indices instead of names. Since using indices can be error prone, JaxSim provides conversion functions for both links:\n",
"\n",
"- **jaxsim.api.link.names_to_idxs()**\n",
"- **jaxsim.api.link.idxs_to_names()**\n",
"\n",
"and frames: \n",
"\n",
"- **jaxsim.api.frame.names_to_idxs()**\n",
"- **jaxsim.api.frame.idxs_to_names()**\n",
"\n",
"We recommend using names whenever possible to avoid hard-to-trace errors.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QxImwfZz7pz-"
},
"outputs": [],
"source": [
"# Find the index of a link.\n",
"link_name = \"l_ankle_2\"\n",
"link_index = js.link.name_to_idx(model=model, link_name=link_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "C22Iqu2i4G-I",
"outputId": "94376151-177d-410f-f375-b7b8bd080992"
},
"outputs": [],
"source": [
"# @title Link Pose\n",
"\n",
"# Compute its pose w.r.t. the world frame through forward kinematics.\n",
"W_H_L = js.link.transform(model=model, data=data, link_index=link_index)\n",
"\n",
"print(f\"Transform of '{link_name}': shape={W_H_L.shape}\\n{W_H_L}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DnSpE_f97RkX",
"outputId": "a3f6b535-4ae5-49f4-8921-7fe4dda5debb"
},
"outputs": [],
"source": [
"# @title Link 6D Velocity\n",
"\n",
"# JaxSim allows to select the so-called representation of the frame velocity.\n",
"L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body)\n",
"LW_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Mixed)\n",
"W_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial)\n",
"\n",
"print(f\"Body-fixed velocity L_v_WL={L_v_WL}\")\n",
"print(f\"Mixed velocity: LW_v_WL={LW_v_WL}\")\n",
"print(f\"Inertial-fixed velocity: W_v_WL={W_v_WL}\")\n",
"\n",
"# These can also be computed passing through the link free-floating Jacobian.\n",
"# This type of Jacobian has an input velocity representation that corresponds to\n",
"# the velocity representation of ν, and an output velocity representation that\n",
"# corresponds to the velocity representation of the desired 6D velocity.\n",
"\n",
"# You can use the following context manager to easily switch between representations.\n",
"with data.switch_velocity_representation(VelRepr.Body):\n",
"\n",
" # Body-fixed generalized velocity.\n",
" B_ν = data.generalized_velocity\n",
"\n",
" # Free-floating Jacobian accepting a body-fixed generalized velocity and\n",
" # returning an inertial-fixed link velocity.\n",
" W_J_WL_B = js.link.jacobian(\n",
" model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial\n",
" )\n",
"\n",
"# Now the following relation should hold.\n",
"assert jnp.allclose(W_v_WL, W_J_WL_B @ B_ν)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SSoziCShtwZ9"
},
"outputs": [],
"source": [
"# Find the index of a frame.\n",
"frame_name = \"l_foot_front\"\n",
"frame_index = js.frame.name_to_idx(model=model, frame_name=frame_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fVp_xP_1twZ9",
"outputId": "cfaa0569-d768-4708-c98c-a5867c056d04"
},
"outputs": [],
"source": [
"# @title Frame Pose\n",
"\n",
"# Compute its pose w.r.t. the world frame through forward kinematics.\n",
"W_H_F = js.frame.transform(model=model, data=data, frame_index=frame_index)\n",
"\n",
"print(f\"Transform of '{frame_name}': shape={W_H_F.shape}\\n{W_H_F}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QqaqxneEFYiW"
},
"outputs": [],
"source": [
"# @title Frame 6D Velocity\n",
"\n",
"# JaxSim allows to select the so-called representation of the frame velocity.\n",
"F_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Body)\n",
"FW_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Mixed)\n",
"W_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial)\n",
"\n",
"print(f\"Body-fixed velocity F_v_WF={F_v_WF}\")\n",
"print(f\"Mixed velocity: FW_v_WF={FW_v_WF}\")\n",
"print(f\"Inertial-fixed velocity: W_v_WF={W_v_WF}\")\n",
"\n",
"# These can also be computed passing through the frame free-floating Jacobian.\n",
"# This type of Jacobian has an input velocity representation that corresponds to\n",
"# the velocity representation of ν, and an output velocity representation that\n",
"# corresponds to the velocity representation of the desired 6D velocity.\n",
"\n",
"# You can use the following context manager to easily switch between representations.\n",
"with data.switch_velocity_representation(VelRepr.Body):\n",
"\n",
" # Body-fixed generalized velocity.\n",
" B_ν = data.generalized_velocity\n",
"\n",
" # Free-floating Jacobian accepting a body-fixed generalized velocity and\n",
" # returning an inertial-fixed frame velocity.\n",
" W_J_WF_B = js.frame.jacobian(\n",
" model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial\n",
" )\n",
"\n",
"# Now the following relation should hold.\n",
"assert jnp.allclose(W_v_WF, W_J_WF_B @ B_ν)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "d_vp6D74GoVZ",
"outputId": "798b9283-792e-4339-b56c-df2595fac974"
},
"source": [
"## Robot Dynamics\n",
"\n",
"JaxSim provides all the quantities involved in the equations of motion, restated here:\n",
"\n",
"$$\n",
"M(\\mathbf{q}) \\dot{\\boldsymbol{\\nu}} + \\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) = B \\boldsymbol{\\tau} + \\sum_{L_i \\in \\mathcal{L}} J_{W,L_i}^\\top(\\mathbf{q}) \\: \\mathbf{f}_i\n",
".\n",
"$$\n",
"\n",
"Specifically, it can compute:\n",
"\n",
"- $M(\\mathbf{q}) \\in \\mathbb{R}^{(6+n)\\times(6+n)}$: the mass matrix.\n",
"- $\\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) \\in \\mathbb{R}^{6+n}$: the vector of bias forces.\n",
"- $B \\in \\mathbb{R}^{(6+n) \\times n}$ the joint selector matrix.\n",
"- $J_{W,L} \\in \\mathbb{R}^{6 \\times (6+n)}$ the Jacobian of link $L$.\n",
"\n",
"Often, for convenience, link Jacobians are stacked together. Since JaxSim efficiently computes the Jacobians for all links, using the stacked version is recommended when needed:\n",
"\n",
"$$\n",
"M(\\mathbf{q}) \\dot{\\boldsymbol{\\nu}} + \\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) = B \\boldsymbol{\\tau} + J_{W,\\mathcal{L}}^\\top(\\mathbf{q}) \\: \\mathbf{f}_\\mathcal{L}\n",
".\n",
"$$\n",
"\n",
"Furthermore, there are applications that require unpacking the vector of bias forces as follows:\n",
"\n",
"$$\n",
"\\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) = C(\\mathbf{q}, \\boldsymbol{\\nu}) \\boldsymbol{\\nu} + \\mathbf{g}(\\mathbf{q})\n",
",\n",
"$$\n",
"\n",
"where:\n",
"\n",
"- $\\mathbf{g}(\\mathbf{q}) \\in \\mathbb{R}^{6+n}$: the vector of gravity forces.\n",
"- $C(\\mathbf{q}, \\boldsymbol{\\nu}) \\in \\mathbb{R}^{(6+n)\\times(6+n)}$: the Coriolis matrix.\n",
"\n",
"Here below we report the functions to compute all these quantities. Note that all quantities depend on the active velocity representation of `data`. As it was done for the link velocity, it is possible to change the representation associated to all the computed quantities by operating within the corresponding context manager. Here below we consider the default representation of data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oOKJOVfsH4Ki"
},
"outputs": [],
"source": [
"print(\"Velocity representation of data:\", data.velocity_representation, \"\\n\")\n",
"\n",
"# Compute the mass matrix.\n",
"M = js.model.free_floating_mass_matrix(model=model, data=data)\n",
"print(f\"M: shape={M.shape}\")\n",
"\n",
"# Compute the vector of bias forces.\n",
"h = js.model.free_floating_bias_forces(model=model, data=data)\n",
"print(f\"h: shape={h.shape}\")\n",
"\n",
"# Compute the vector of gravity forces.\n",
"g = js.model.free_floating_gravity_forces(model=model, data=data)\n",
"print(f\"g: shape={g.shape}\")\n",
"\n",
"# Compute the Coriolis matrix.\n",
"C = js.model.free_floating_coriolis_matrix(model=model, data=data)\n",
"print(f\"C: shape={C.shape}\")\n",
"\n",
"# Create the joint selector matrix.\n",
"B = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T\n",
"print(f\"B: shape={B.shape}\")\n",
"\n",
"# Compute the stacked tensor of link Jacobians.\n",
"J = js.model.generalized_free_floating_jacobian(model=model, data=data)\n",
"print(f\"J: shape={J.shape}\")\n",
"\n",
"# Extract the joint forces from the references object.\n",
"τ = references.joint_force_references(model=model)\n",
"print(f\"τ: shape={τ.shape}\")\n",
"\n",
"# Extract the link forces from the references object.\n",
"f_L = references.link_forces(model=model, data=data)\n",
"print(f\"f_L: shape={f_L.shape}\")\n",
"\n",
"# The following relation should hold.\n",
"assert jnp.allclose(h, C @ ν + g)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "FlNo8dNWKKtu",
"outputId": "313e939b-f88f-4407-c9ee-b5b3b7443061"
},
"source": [
"### Forward Dynamics\n",
"\n",
"$$\n",
"\\dot{\\boldsymbol{\\nu}} = \\text{FD}(\\mathbf{q}, \\boldsymbol{\\nu}, \\boldsymbol{\\tau}, \\mathbf{f}_{\\mathcal{L}})\n",
"$$\n",
"\n",
"JaxSim provides two alternative methods to compute the forward dynamics:\n",
"\n",
"1. Operate on the quantities of the equations of motion.\n",
"2. Call the recursive Articulated Body Algorithm (ABA).\n",
"\n",
"The physics engine provided by JaxSim exploits the efficient calculation of the forward dynamics with ABA for simulating the trajectories of the system dynamics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LXARuRu1Ly1K"
},
"outputs": [],
"source": [
"ν̇_eom = jnp.linalg.pinv(M) @ (B @ τ - h + jnp.einsum(\"l6g,l6->g\", J, f_L))\n",
"\n",
"v̇_WB, s̈ = js.model.forward_dynamics_aba(\n",
" model=model, data=data, link_forces=f_L, joint_forces=joint_force_references\n",
")\n",
"\n",
"ν̇_aba = jnp.hstack([v̇_WB, s̈])\n",
"print(f\"ν̇: shape={ν̇_aba.shape}\") # noqa: RUF001\n",
"\n",
"# The following relation should hold.\n",
"assert jnp.allclose(ν̇_eom, ν̇_aba)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "g5GOYXDnLySU",
"outputId": "ad4ce77d-d06f-473a-9c32-040680d76aa5"
},
"source": [
"### Inverse Dynamics\n",
"\n",
"$$\n",
"(\\boldsymbol{\\tau}, \\, \\mathbf{f}_B) = \\text{ID}(\\mathbf{q}, \\boldsymbol{\\nu}, \\dot{\\boldsymbol{\\nu}}, \\mathbf{f}_{\\mathcal{L}})\n",
"$$\n",
"\n",
"JaxSim offers two methods to compute inverse dynamics:\n",
"\n",
"- Directly use the quantities from the equations of motion.\n",
"- Use the Recursive Newton-Euler Algorithm (RNEA).\n",
"\n",
"Unlike many other implementations, JaxSim's RNEA for floating-base systems is the true inverse of $\\text{FD}$. It also computes the 6D force applied to the base link that generates the base acceleration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UTae5MjhaP2H"
},
"outputs": [],
"source": [
"f_B, τ_rnea = js.model.inverse_dynamics(\n",
" model=model,\n",
" data=data,\n",
" base_acceleration=v̇_WB,\n",
" joint_accelerations=s̈,\n",
" # To check that f_B works, let's remove the force applied\n",
" # to the base link from the link forces.\n",
" link_forces=f_L.at[0].set(jnp.zeros(6))\n",
")\n",
"\n",
"print(f\"f_B: shape={f_B.shape}\")\n",
"print(f\"τ_rnea: shape={τ_rnea.shape}\")\n",
"\n",
"# The following relations should hold.\n",
"assert jnp.allclose(τ_rnea, τ)\n",
"assert jnp.allclose(f_B, link_forces[0])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gYZ1jK1Neg1H",
"outputId": "0de79770-1e18-4027-bb47-5713bc1b4a72"
},
"source": [
"### Centroidal Dynamics\n",
"\n",
"Centroidal dynamics is a useful simplification often employed in planning and control applications. It represents the dynamics projected onto a mixed frame associated with the center of mass (CoM):\n",
"\n",
"$$\n",
"G = G[W] = ({}^W \\mathbf{p}_{\\text{CoM}}, [W])\n",
".\n",
"$$\n",
"\n",
"The governing equations for centroidal dynamics take into account the 6D centroidal momentum:\n",
"\n",
"$$\n",
"{}_G \\mathbf{h} =\n",
"\\begin{bmatrix}\n",
"{}_G \\mathbf{h}^l \\\\ {}_G \\mathbf{h}^\\omega\n",
"\\end{bmatrix} =\n",
"\\begin{bmatrix}\n",
"m \\, {}^W \\dot{\\mathbf{p}}_\\text{CoM} \\\\ {}_G \\mathbf{h}^\\omega\n",
"\\end{bmatrix}\n",
"\\in \\mathbb{R}^6\n",
".\n",
"$$\n",
"\n",
"The equations of centroidal dynamics can be expressed as:\n",
"\n",
"$$\n",
"{}_G \\dot{\\mathbf{h}} =\n",
"m \\,\n",
"\\begin{bmatrix}\n",
"{}^W \\mathbf{g} \\\\ \\mathbf{0}_3\n",
"\\end{bmatrix} +\n",
"\\sum_{C_i \\in \\mathcal{C}} {}_G \\mathbf{X}^{C_i} \\, {}_{C_i} \\mathbf{f}_i\n",
".\n",
"$$\n",
"\n",
"While centroidal dynamics can function independently by considering the total mass $m \\in \\mathbb{R}$ of the robot and the transformations for 6D contact forces ${}_G \\mathbf{X}^{C_i}$ corresponding to the pose ${}^G \\mathbf{H}_{C_i} \\in \\text{SE}(3)$ of the contact frames, advanced kino-dynamic methods may require a relationship between full kinematics and centroidal dynamics. This is typically achieved through the _Centroidal Momentum Matrix_ (also known as the _centroidal momentum Jacobian_):\n",
"\n",
"$$\n",
"{}_G \\mathbf{h} = J_\\text{CMM}(\\mathbf{q}) \\, \\boldsymbol{\\nu}\n",
".\n",
"$$\n",
"\n",
"JaxSim offers APIs to compute all these quantities (and many more) in the `jaxsim.api.com` package."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rrSfxp8lh9YZ"
},
"outputs": [],
"source": [
"# Number of contact points.\n",
"n_cp = len(model.kin_dyn_parameters.contact_parameters.body)\n",
"print(\"Number of contact points:\", n_cp, \"\\n\")\n",
"\n",
"# Compute the centroidal momentum.\n",
"J_CMM = js.com.centroidal_momentum_jacobian(model=model, data=data)\n",
"G_h = J_CMM @ ν\n",
"print(f\"G_h: shape={G_h.shape}\")\n",
"print(f\"J_CMM: shape={J_CMM.shape}\")\n",
"\n",
"# The following relation should hold.\n",
"assert jnp.allclose(G_h, js.com.centroidal_momentum(model=model, data=data))\n",
"\n",
"# If we consider all contact points of the model as active\n",
"# (discouraged, since there might be too many), the 6D transforms of\n",
"# collidable points can be computed as follows:\n",
"W_H_C = js.contact.transforms(model=model, data=data)\n",
"\n",
"# Compute the pose of the G frame.\n",
"W_p_CoM = js.com.com_position(model=model, data=data)\n",
"G_H_W = jaxsim.math.Transform.inverse(jnp.eye(4).at[0:3, 3].set(W_p_CoM))\n",
"\n",
"# Convert from SE(3) to the transforms for 6D forces.\n",
"G_Xf_C = jax.vmap(\n",
" lambda W_H_Ci: jaxsim.math.Adjoint.from_transform(\n",
" transform=G_H_W @ W_H_Ci, inverse=True\n",
" )\n",
")(W_H_C)\n",
"print(f\"G_Xf_C: shape={G_Xf_C.shape}\")\n",
"\n",
"# Let's create random 3D linear forces applied to the contact points.\n",
"C_fl = jax.random.uniform(\n",
" minval=-10.0,\n",
" maxval=10.0,\n",
" shape=(n_cp, 3),\n",
" key=jax.random.PRNGKey(0),\n",
")\n",
"\n",
"# Compute the total mass of the robot.\n",
"m = js.model.total_mass(model=model)\n",
"\n",
"# The centroidal dynamics can be computed as follows.\n",
"G_ḣ = 0\n",
"G_ḣ += m * jnp.hstack([0, 0, model.gravity, 0, 0, 0])\n",
"G_ḣ += jnp.einsum(\"c66,c6->6\", G_Xf_C, jnp.hstack([C_fl, jnp.zeros_like(C_fl)]))\n",
"print(f\"G_ḣ: shape={G_ḣ.shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ot6HePB_twaE",
"outputId": "02a6abae-257e-45ee-e9de-6a607cdbeb9a"
},
"source": [
"## Contact Frames\n",
"\n",
"Many control and planning applications require projecting the floating-base dynamics into the contact space or computing quantities related to active contact points, such as enforcing holonomic constraints.\n",
"\n",
"The underlying theory for these applications becomes clearer in a mixed representation. Specifically, the position, linear velocity, and linear acceleration of contact points in their corresponding mixed frame align with the numerical derivatives of their coordinate vectors.\n",
"\n",
"Key methodologies in this area may involve the Delassus matrix:\n",
"\n",
"$$\n",
"\\Psi(\\mathbf{q}) = J_{W,C}(\\mathbf{q}) \\, M(\\mathbf{q})^{-1} \\, J_{W,C}^T(\\mathbf{q})\n",
"$$\n",
"\n",
"or the linear acceleration of a contact point:\n",
"\n",
"$$\n",
"{}^W \\ddot{\\mathbf{p}}_C = \\frac{\\text{d} (J^l_{W,C} \\boldsymbol{\\nu})}{\\text{d}t}\n",
"= \\dot{J}^l_{W,C} \\boldsymbol{\\nu} + J^l_{W,C} \\dot{\\boldsymbol{\\nu}}\n",
".\n",
"$$\n",
"\n",
"JaxSim offers APIs to compute all these quantities (and many more) in the `jaxsim.api.contact` package."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LITRC3STliKR"
},
"outputs": [],
"source": [
"with (\n",
" data.switch_velocity_representation(VelRepr.Mixed),\n",
" references.switch_velocity_representation(VelRepr.Mixed),\n",
"):\n",
"\n",
" # Compute the mixed generalized velocity.\n",
" BW_ν = data.generalized_velocity\n",
"\n",
" # Compute the mixed generalized acceleration.\n",
" BW_ν̇ = jnp.hstack(\n",
" js.model.forward_dynamics(\n",
" model=model,\n",
" data=data,\n",
" link_forces=references.link_forces(model=model, data=data),\n",
" joint_forces=references.joint_force_references(model=model),\n",
" )\n",
" )\n",
"\n",
" # Compute the mass matrix in mixed representation.\n",
" BW_M = js.model.free_floating_mass_matrix(model=model, data=data)\n",
"\n",
" # Compute the contact Jacobian and its derivative.\n",
" Jl_WC = js.contact.jacobian(model=model, data=data)[:, 0:3, :]\n",
" J̇l_WC = js.contact.jacobian_derivative(model=model, data=data)[:, 0:3, :]\n",
"\n",
"# Compute the Delassus matrix.\n",
"Ψ = jnp.vstack(Jl_WC) @ jnp.linalg.lstsq(BW_M, jnp.vstack(Jl_WC).T)[0]\n",
"print(f\"Ψ: shape={Ψ.shape}\")\n",
"\n",
"# Compute the transforms of the mixed frames implicitly associated\n",
"# to each collidable point.\n",
"W_H_C = js.contact.transforms(model=model, data=data)\n",
"print(f\"W_H_C: shape={W_H_C.shape}\")\n",
"\n",
"# Compute the linear velocity of the collidable points.\n",
"with data.switch_velocity_representation(VelRepr.Mixed):\n",
" W_ṗ_B = js.contact.collidable_point_velocities(model=model, data=data)[:, 0:3]\n",
" print(f\"W_ṗ_B: shape={W_ṗ_B.shape}\")\n",
"\n",
"# Compute the linear acceleration of the collidable points.\n",
"W_p̈_C = 0\n",
"W_p̈_C += jnp.einsum(\"c3g,g->c3\", J̇l_WC, BW_ν)\n",
"W_p̈_C += jnp.einsum(\"c3g,g->c3\", Jl_WC, BW_ν̇)\n",
"print(f\"W_p̈_C: shape={W_p̈_C.shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LITRC3STliKR"
},
"source": [
"## Conclusions\n",
"\n",
"This notebook provided an overview of the main APIs in JaxSim for its use as a multibody dynamics library. Here are a few key points to remember:\n",
"\n",
"- Explore all the modules in the `jaxsim.api` package to discover the full range of APIs available. Many more functionalities exist beyond what was covered in this notebook.\n",
"- All APIs follow a functional approach, consistent with the JAX programming style.\n",
"- This functional design allows for easy application of `jax.vmap` to execute functions in parallel on hardware accelerators.\n",
"- Since the entire multibody dynamics library is built with JAX, it natively supports `jax.grad`, `jax.jacfwd`, and `jax.jacrev` transformations, enabling automatic differentiation through complex logic without additional effort.\n",
"\n",
"Have fun!"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuClass": "premium",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "comodo_jaxsim",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: examples/jaxsim_as_physics_engine.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "H-WgcgGQaTG7"
},
"source": [
"# JaxSim as a hardware-accelerated parallel physics engine\n",
"\n",
"This notebook shows how to use the key APIs to load a robot model and simulate multiple trajectories simultaneously.\n",
"\n",
"\n",
" \n",
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SgOSnrSscEkt"
},
"source": [
"## Prepare the environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fdqvAqMDaTG9"
},
"outputs": [],
"source": [
"# @title Imports and setup\n",
"import os\n",
"import sys\n",
"from IPython.display import clear_output\n",
"\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"\n",
"# Install JAX and Gazebo\n",
"if IS_COLAB:\n",
" !{sys.executable} -m pip install --pre -qU jaxsim\n",
" !apt install -qq lsb-release wget gnupg\n",
" !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n",
" !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n",
" !apt -qq update\n",
" !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n",
"\n",
" clear_output()\n",
"\n",
"# Set environment variable to avoid GPU out of memory errors\n",
"%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n",
"\n",
"\n",
"# ================\n",
"# Notebook imports\n",
"# ================\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jaxsim.api as js\n",
"from jaxsim import logging\n",
"import pathlib\n",
"\n",
"logging.set_logging_level(logging.LoggingLevel.WARNING)\n",
"print(f\"Running on {jax.devices()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NqjuZKvOaTG_"
},
"source": [
"## Prepare the simulation\n",
"\n",
"JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. In this example, we will load the [ergoCub][ergocub] model urdf.\n",
"\n",
"[sdformat]: http://sdformat.org/\n",
"[urdf]: http://wiki.ros.org/urdf/\n",
"[ergocub]: https://ergocub.eu/\n",
"[rod]: https://github.com/gbionics/rod\n",
"\n",
"### Create the model and its data\n",
" To define a simulation we need two main objects:\n",
"\n",
"- `model`: an object that defines the dynamics of the system.\n",
"- `data`: an object that contains the state of the system.\n",
"\n",
"\n",
"The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n",
"To see the advanced usage, check the advanced example, where you will see how to pass explicitly an integrator class and state to the `model` object and how to change the contact model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create the model "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "etQ577cFaTHA"
},
"outputs": [],
"source": [
"# Create the JaxSim model.\n",
"try:\n",
" os.environ[\"ROBOT_DESCRIPTION_COMMIT\"] = \"v0.7.7\"\n",
"\n",
" import robot_descriptions.ergocub_description\n",
"\n",
"finally:\n",
" _ = os.environ.pop(\"ROBOT_DESCRIPTION_COMMIT\", None)\n",
"\n",
"model_description_path = pathlib.Path(\n",
" robot_descriptions.ergocub_description.URDF_PATH.replace(\n",
" \"ergoCubSN002\", \"ergoCubSN001\"\n",
" )\n",
")\n",
"\n",
"clear_output()\n",
"\n",
"full_model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_description_path,\n",
" time_step=0.0001,\n",
" is_urdf=True\n",
")\n",
"\n",
"joints_list = tuple(('l_shoulder_pitch', 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow',\n",
" 'r_shoulder_pitch', 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow',\n",
" 'l_hip_pitch', 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',\n",
" 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll'))\n",
"\n",
"model = js.model.reduce(\n",
" model=full_model,\n",
" considered_joints=joints_list\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create the data object \n",
"\n",
"The data object is never changed by reference. Anytime you call a method aimed at modifying data, like `reset_base_position`, a new data object will be returned with the updated attributes while the original data will not be changed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create the data of a single model.\n",
"data = js.data.JaxSimModelData.build(model=model, base_position=jnp.array([0.0, 0.0, 1.0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Simulation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a random JAX key.\n",
"\n",
"key = jax.random.PRNGKey(seed=0)\n",
"\n",
"# Initialize the simulated time.\n",
"T = jnp.arange(start=0, stop=0.3, step=model.time_step)\n",
"\n",
"# Simulate\n",
"for _t in T:\n",
" data = js.model.step(\n",
" model=model,\n",
" data=data,\n",
" link_forces=None,\n",
" joint_force_references=None,\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Vectorized simulation \n",
"\n",
"We will now vectorize the simulation on batched data using `jax.vmap`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# first we have to vmap the function\n",
"\n",
"import functools\n",
"from typing import Any\n",
"\n",
"\n",
"@jax.jit\n",
"def step_single(\n",
" model: js.model.JaxSimModel,\n",
" data: js.data.JaxSimModelData,\n",
") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
"\n",
" # Close step over static arguments.\n",
" return js.model.step(\n",
" model=model,\n",
" data=data,\n",
" link_forces=None,\n",
" joint_force_references=None,\n",
" )\n",
"\n",
"\n",
"@jax.jit\n",
"@functools.partial(jax.vmap, in_axes=(None, 0))\n",
"def step_parallel(\n",
" model: js.model.JaxSimModel,\n",
" data: js.data.JaxSimModelData,\n",
") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
"\n",
" return step_single(\n",
" model=model, data=data\n",
" )\n",
"\n",
"\n",
"# Then we have to create the vector of initial state\n",
"batch_size = 5\n",
"data_batch_t0 = jax.vmap(\n",
" lambda pos: js.data.JaxSimModelData.build(model=model, base_position=pos))(jnp.tile(jnp.array([0.0, 0.0, 1.0]), (batch_size, 1)))\n",
"\n",
"data = data_batch_t0\n",
"for _t in T:\n",
" data = step_parallel(model, data)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuClass": "premium",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "jaxsim",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: examples/jaxsim_as_physics_engine_advanced.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "H-WgcgGQaTG7"
},
"source": [
"# JaxSim as a hardware-accelerated parallel physics engine-advanced usage\n",
"\n",
"JaxSim is developed to optimize synthetic data generation by sampling trajectories using hardware accelerators such as GPUs and TPUs.\n",
"\n",
"In this notebook, you'll learn how to use the key APIs to load a simple robot model (a sphere) and simulate multiple trajectories in parallel on GPUs.\n",
"\n",
"\n",
" \n",
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SgOSnrSscEkt"
},
"source": [
"## Prepare the environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fdqvAqMDaTG9"
},
"outputs": [],
"source": [
"# @title Imports and setup\n",
"import sys\n",
"from IPython.display import clear_output\n",
"\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"\n",
"# Install JAX and Gazebo\n",
"if IS_COLAB:\n",
" !{sys.executable} -m pip install --pre -qU jaxsim[viz]\n",
" !apt install -qq lsb-release wget gnupg\n",
" !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n",
" !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n",
" !apt -qq update\n",
" !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n",
"\n",
" clear_output()\n",
"\n",
"# Set environment variable to avoid GPU out of memory errors\n",
"%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n",
"\n",
"# ================\n",
"# Notebook imports\n",
"# ================\n",
"\n",
"import os\n",
"\n",
"if sys.platform == 'darwin':\n",
" os.environ[\"MUJOCO_GL\"] = \"glfw\"\n",
"else:\n",
" os.environ[\"MUJOCO_GL\"] = \"egl\"\n",
"\n",
"import jax\n",
"\n",
"import jax.numpy as jnp\n",
"import jaxsim.api as js\n",
"import rod\n",
"from jaxsim import logging\n",
"from rod.builder.primitives import SphereBuilder\n",
"\n",
"logging.set_logging_level(logging.LoggingLevel.WARNING)\n",
"print(f\"Running on {jax.devices()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QtCCUhdpdGFH"
},
"source": [
"## Prepare the simulation\n",
"\n",
"JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. This is done using the [`gbionics/rod`][rod] library, which processes these formats.\n",
"\n",
"The `rod` library also allows creating in-memory models that can be serialized to SDF or URDF. We'll use this functionality to build a sphere model, which will later be used to create the JaxSim model.\n",
"\n",
"[sdformat]: http://sdformat.org/\n",
"[urdf]: http://wiki.ros.org/urdf/\n",
"[rod]: https://github.com/gbionics/rod"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "0emoMQhCaTG_"
},
"outputs": [],
"source": [
"# @title Create the model description of a sphere\n",
"\n",
"# Create a SDF model.\n",
"# The builder takes care to compute the right inertia tensor for you.\n",
"rod_sdf = rod.Sdf(\n",
" version=\"1.7\",\n",
" model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n",
" .build_model()\n",
" .add_link()\n",
" .add_inertial()\n",
" .add_visual()\n",
" .add_collision()\n",
" .build(),\n",
")\n",
"\n",
"# Rod allows to update the frames w.r.t. the poses are expressed.\n",
"rod_sdf.model.switch_frame_convention(\n",
" frame_convention=rod.FrameConvention.Urdf, explicit_frames=True\n",
")\n",
"\n",
"# Serialize the model to a SDF string.\n",
"model_sdf_string = rod_sdf.serialize(pretty=True)\n",
"print(model_sdf_string)\n",
"\n",
"# JaxSim currently only supports collisions between points attached to bodies\n",
"# and a ground surface modeled as a heightmap sampled from a smooth function.\n",
"# While this approach is universal as it applies to generic meshes, the number\n",
"# of considered points greatly affects the performance. Spheres, by default,\n",
"# are discretized with 250 points. It's too much for this simple example.\n",
"# This number can be decreased with the following environment variable.\n",
"os.environ[\"JAXSIM_COLLISION_SPHERE_POINTS\"] = \"50\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NqjuZKvOaTG_"
},
"source": [
"### Create the model and its data\n",
"\n",
"JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. Four main elements are used to define a simulation:\n",
"\n",
"- `model`: an object that defines the dynamics of the system.\n",
"- `data`: an object that contains the state of the system.\n",
"- `integrator` *(Optional)*: an object that defines the integration method.\n",
"- `integrator_metadata` *(Optional)*: an object that contains the state of the integrator.\n",
"\n",
"The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n",
"In this example, we will explicitly pass an integrator class to the `model` object and we will use the default `SoftContacts` contact model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "etQ577cFaTHA"
},
"outputs": [],
"source": [
"# Create the JaxSim model.\n",
"# This is shared among all the parallel instances.\n",
"model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_sdf_string,\n",
" time_step=0.001,\n",
")\n",
"\n",
"# Create the data of a single model.\n",
"# We will create a vectorized instance later.\n",
"data_single = js.data.JaxSimModelData.zero(model=model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o86Teq5piVGj"
},
"outputs": [],
"source": [
"# Initialize the simulated time.\n",
"T = jnp.arange(start=0, stop=1.0, step=model.time_step)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V6IeD2B3m4F0"
},
"source": [
"## Sample a batch of trajectories in parallel\n",
"\n",
"With the provided resources, you can step through an open-loop trajectory on a single model using `jaxsim.api.model.step`.\n",
"\n",
"In this notebook, we'll focus on running parallel steps. We'll use JAX's automatic vectorization to apply the step function to batched data.\n",
"\n",
"Note that these parallel simulations are independent — models don't interact, so there's no need to avoid initial collisions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vtEn0aIzr_2j"
},
"outputs": [],
"source": [
"# @title Generate batched initial data\n",
"\n",
"# Create a random JAX key.\n",
"key = jax.random.PRNGKey(seed=0)\n",
"\n",
"# Split subkeys for sampling random initial data.\n",
"batch_size = 9\n",
"row_length = int(jnp.sqrt(batch_size))\n",
"row_dist = 0.3 * row_length\n",
"key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n",
"\n",
"# Create the batched data by sampling the height from [0.5, 0.6] meters.\n",
"data_batch_t0 = jax.vmap(\n",
" lambda key: js.data.random_model_data(\n",
" model=model,\n",
" key=key,\n",
" base_pos_bounds=([0, 0, 0.3], [0, 0, 1.2]),\n",
" base_vel_lin_bounds=(0, 0),\n",
" base_vel_ang_bounds=(0, 0),\n",
" )\n",
")(jnp.vstack(subkeys))\n",
"\n",
"x, y = jnp.meshgrid(\n",
" jnp.linspace(-row_dist, row_dist, num=row_length),\n",
" jnp.linspace(-row_dist, row_dist, num=row_length),\n",
")\n",
"xy_coordinate = jnp.stack([x.flatten(), y.flatten()], axis=-1)\n",
"\n",
"# Reset the x and y position to a grid.\n",
"data_batch_t0 = data_batch_t0.replace(\n",
" model=model,\n",
" base_position=data_batch_t0.base_position.at[:, :2].set(xy_coordinate),\n",
")\n",
"\n",
"print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position[0:10])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0tQPfsl6uxHm"
},
"outputs": [],
"source": [
"# @title Create parallel step function\n",
"\n",
"import functools\n",
"from typing import Any\n",
"\n",
"\n",
"@jax.jit\n",
"def step_single(\n",
" model: js.model.JaxSimModel,\n",
" data: js.data.JaxSimModelData,\n",
") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
"\n",
" # Close step over static arguments.\n",
" return js.model.step(\n",
" model=model,\n",
" data=data,\n",
" link_forces=None,\n",
" joint_force_references=None,\n",
" )\n",
"\n",
"\n",
"@jax.jit\n",
"@functools.partial(jax.vmap, in_axes=(None, 0))\n",
"def step_parallel(\n",
" model: js.model.JaxSimModel,\n",
" data: js.data.JaxSimModelData,\n",
") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
"\n",
" return step_single(\n",
" model=model, data=data\n",
" )\n",
"\n",
"\n",
"# The first run will be slow since JAX needs to JIT-compile the functions.\n",
"_ = step_single(model, data_single)\n",
"_ = step_parallel(model, data_batch_t0)\n",
"\n",
"# Benchmark the execution of a single step.\n",
"print(\"\\nSingle simulation step:\")\n",
"%timeit step_single(model, data_single)\n",
"\n",
"# On hardware accelerators, there's a range of batch_size values where\n",
"# increasing the number of parallel instances doesn't affect computation time.\n",
"# This range depends on the GPU/TPU specifications.\n",
"print(f\"\\nParallel simulation steps (batch_size={batch_size} on {jax.devices()[0]}):\")\n",
"%timeit step_parallel(model, data_batch_t0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VNwzT2JQ1n15"
},
"outputs": [],
"source": [
"# @title Run parallel simulation\n",
"\n",
"data = data_batch_t0\n",
"data_trajectory_list = []\n",
"\n",
"for _ in T:\n",
"\n",
" data = step_parallel(model, data)\n",
" data_trajectory_list.append(data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y6n720Cr3G44"
},
"source": [
"## Visualize trajectory"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BLPODyKr3Lyg"
},
"outputs": [],
"source": [
"# Convert a list of PyTrees to a batched PyTree.\n",
"# This operation is called 'tree transpose' in JAX.\n",
"data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\n",
"\n",
"print(f\"W_p_B: shape={data_trajectory.base_position.shape}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-jxJXy5r3RMt"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"plt.plot(T, data_trajectory.base_position[:, :, 2])\n",
"plt.grid(True)\n",
"plt.xlabel(\"Time [s]\")\n",
"plt.ylabel(\"Height [m]\")\n",
"plt.title(\"Height trajectory of the sphere\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jaxsim.mujoco\n",
"\n",
"mjcf_string, assets = jaxsim.mujoco.ModelToMjcf.convert(\n",
" model.built_from,\n",
" cameras=jaxsim.mujoco.loaders.MujocoCamera.build_from_target_view(\n",
" camera_name=\"sphere_cam\",\n",
" lookat=[0, 0, 0.3],\n",
" distance=4,\n",
" azimuth=150,\n",
" elevation=-10,\n",
" ),\n",
")\n",
"\n",
"# Create a helper for each parallel instance.\n",
"mj_model_helpers = [\n",
" jaxsim.mujoco.MujocoModelHelper.build_from_xml(\n",
" mjcf_description=mjcf_string, assets=assets\n",
" )\n",
" for _ in range(batch_size)\n",
"]\n",
"\n",
"# Create the video recorder.\n",
"recorder = jaxsim.mujoco.MujocoVideoRecorder(\n",
" model=mj_model_helpers[0].model,\n",
" data=[helper.data for helper in mj_model_helpers],\n",
" fps=int(1 / model.time_step),\n",
" width=320 * 2,\n",
" height=240 * 2,\n",
")\n",
"\n",
"for data_t in data_trajectory_list:\n",
"\n",
" for helper, base_position, base_quaternion, joint_position in zip(\n",
" mj_model_helpers,\n",
" data_t.base_position,\n",
" data_t.base_orientation,\n",
" data_t.joint_positions,\n",
" strict=True,\n",
" ):\n",
" helper.set_base_position(position=base_position)\n",
" helper.set_base_orientation(orientation=base_quaternion)\n",
"\n",
" if model.dofs() > 0:\n",
" helper.set_joint_positions(\n",
" positions=joint_position, joint_names=model.joint_names()\n",
" )\n",
"\n",
" # Record a new video frame.\n",
" recorder.record_frame(camera_name=\"sphere_cam\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import mediapy as media\n",
"\n",
"media.show_video(recorder.frames, fps=recorder.fps)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuClass": "premium",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "jaxpypi",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: examples/jaxsim_for_robot_controllers.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "EhPy6FgiZH4d"
},
"source": [
"# JaxSim for developing closed-loop robot controllers\n",
"\n",
"Originally developed as a **hardware-accelerated physics engine**, JaxSim has expanded its capabilities to become a full-featured **JAX-based multibody dynamics library**.\n",
"\n",
"In this notebook, you'll explore how to combine these two core features. Specifically, you'll learn how to load a robot model and design a model-based controller for closed-loop simulations.\n",
"\n",
"\n",
" \n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vsf1AlxdZH4f"
},
"outputs": [],
"source": [
"# @title Prepare the environment\n",
"from IPython.display import clear_output\n",
"import sys\n",
"\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"\n",
"# Install JAX, sdformat, and other notebook dependencies.\n",
"if IS_COLAB:\n",
" !{sys.executable} -m pip install --pre -qU jaxsim[viz]\n",
" !apt install -qq lsb-release wget gnupg\n",
" !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n",
" !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n",
" !apt -qq update\n",
" !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n",
"\n",
" clear_output()\n",
"\n",
"# ================\n",
"# Notebook imports\n",
"# ================\n",
"\n",
"import os\n",
"\n",
"if sys.platform == 'darwin':\n",
" os.environ[\"MUJOCO_GL\"] = \"glfw\"\n",
"else:\n",
" os.environ[\"MUJOCO_GL\"] = \"egl\"\n",
"\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jaxsim.mujoco\n",
"from jaxsim import logging\n",
"\n",
"logging.set_logging_level(logging.LoggingLevel.WARNING)\n",
"print(f\"Running on {jax.devices()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kN-b9nOsZH4g"
},
"source": [
"We will use a simple cartpole model for this example. The cartpole model is a 2D model with a cart that can move horizontally and a pole that can rotate around the cart. The state of the cartpole is given by the position of the cart, the angle of the pole, the velocity of the cart, and the angular velocity of the pole. The control input is the horizontal force applied to the cart."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5aLqrZDqR5LA"
},
"source": [
"## Prepare the simulation\n",
"\n",
"JaxSim supports loading robot models from both [SDF][sdformat] and [URDF][urdf] files, utilizing the [`gbionics/rod`][rod] library for processing these formats.\n",
"\n",
"The `rod` library library can read URDF files and validates them internally using [`gazebosim/sdformat`][sdformat_github]. In this example, we'll load a cart-pole model, which will be used to create the JaxSim simulation model.\n",
"\n",
"[sdformat]: http://sdformat.org/\n",
"[urdf]: http://wiki.ros.org/urdf/\n",
"[rod]: https://github.com/gbionics/rod\n",
"[sdformat_github]: https://github.com/gazebosim/sdformat"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.path.abspath(\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PZM7hEvFZH4h"
},
"outputs": [],
"source": [
"# @title Load the URDF model\n",
"import pathlib\n",
"import urllib\n",
"\n",
"# Retrieve the file\n",
"url = \"https://raw.githubusercontent.com/gbionics/jaxsim/refs/heads/main/examples/assets/cartpole.urdf\"\n",
"model_path, _ = urllib.request.urlretrieve(url)\n",
"model_urdf_string = pathlib.Path(model_path).read_text()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M5XsKehvZH4j"
},
"outputs": [],
"source": [
"# @title Create the model and its data\n",
"\n",
"import jaxsim.api as js\n",
"\n",
"# Create the model from the model description.\n",
"model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_urdf_string,\n",
" time_step=0.010,\n",
")\n",
"\n",
"# Create the data storing the simulation state.\n",
"data_zero = js.data.JaxSimModelData.zero(model=model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jk9csR5ETgn1"
},
"outputs": [],
"source": [
"# @title Define simulation parameters\n",
"\n",
"# Initialize the simulated time.\n",
"T = jnp.arange(start=0, stop=5.0, step=model.time_step)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bo6Ke5nAWL-S"
},
"source": [
"## Prepare the MuJoCo renderer\n",
"\n",
"For visualization purpose, we use the passive viewer of the MuJoCo simulator. It allows to either open an interactive windows when used locally or record a video when used in notebooks."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "j1_I2i5TZH4n"
},
"outputs": [],
"source": [
"# Create the MJCF resources from the URDF.\n",
"mjcf_string, assets = jaxsim.mujoco.UrdfToMjcf.convert(\n",
" urdf=model.built_from,\n",
" # Create the camera used by the recorder.\n",
" cameras=jaxsim.mujoco.loaders.MujocoCamera.build_from_target_view(\n",
" camera_name=\"cartpole_camera\",\n",
" lookat=js.link.com_position(\n",
" model=model,\n",
" data=data_zero,\n",
" link_index=js.link.name_to_idx(model=model, link_name=\"cart\"),\n",
" in_link_frame=False,\n",
" ),\n",
" distance=3,\n",
" azimuth=150,\n",
" elevation=-10,\n",
" ),\n",
")\n",
"\n",
"# Create a helper to operate on the MuJoCo model and data.\n",
"mj_model_helper = jaxsim.mujoco.MujocoModelHelper.build_from_xml(\n",
" mjcf_description=mjcf_string, assets=assets\n",
")\n",
"\n",
"# Create the video recorder.\n",
"recorder = jaxsim.mujoco.MujocoVideoRecorder(\n",
" model=mj_model_helper.model,\n",
" data=mj_model_helper.data,\n",
" fps=int(1 / model.time_step),\n",
" width=320 * 2,\n",
" height=240 * 2,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DpRvvGujZH4o"
},
"source": [
"## Open-loop simulation\n",
"\n",
"Now, let's run a simulation to demonstrate the open-loop dynamics of the system."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gSWzcsKWZH4p"
},
"outputs": [],
"source": [
"import mediapy as media\n",
"\n",
"\n",
"# Create a random joint position.\n",
"# For a random full state, you can use jaxsim.api.data.random_model_data.\n",
"random_joint_positions = jax.random.uniform(\n",
" minval=-1.0,\n",
" maxval=1.0,\n",
" shape=(model.dofs(),),\n",
" key=jax.random.PRNGKey(0),\n",
")\n",
"\n",
"# Reset the state to the random joint positions.\n",
"data = js.data.JaxSimModelData.build(model=model, joint_positions=random_joint_positions)\n",
"\n",
"for _ in T:\n",
"\n",
" # Step the JaxSim simulation.\n",
" data = js.model.step(\n",
" model=model,\n",
" data=data,\n",
" joint_force_references=None,\n",
" link_forces=None,\n",
" )\n",
"\n",
" # Update the MuJoCo data.\n",
" mj_model_helper.set_joint_positions(\n",
" positions=data.joint_positions, joint_names=model.joint_names()\n",
" )\n",
"\n",
" # Record a new video frame.\n",
" recorder.record_frame(camera_name=\"cartpole_camera\")\n",
"\n",
"\n",
"# Play the video.\n",
"media.show_video(recorder.frames, fps=recorder.fps)\n",
"recorder.frames = []"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j1rguK3UZH4p"
},
"source": [
"## Closed-loop simulation\n",
"\n",
"Next, let's design a simple computed torque controller. The equations of motion for the cart-pole system are given by:\n",
"\n",
"$$\n",
"M_{ss}(\\mathbf{s}) \\, \\ddot{\\mathbf{s}} + \\mathbf{h}_s(\\mathbf{s}, \\dot{\\mathbf{s}}) = \\boldsymbol{\\tau}\n",
",\n",
"$$\n",
"\n",
"where:\n",
"\n",
"- $\\mathbf{s} \\in \\mathbb{R}^n$ are the joint positions.\n",
"- $\\dot{\\mathbf{s}} \\in \\mathbb{R}^n$ are the joint velocities.\n",
"- $\\ddot{\\mathbf{s}} \\in \\mathbb{R}^n$ are the joint accelerations.\n",
"- $\\boldsymbol{\\tau} \\in \\mathbb{R}^n$ are the joint torques.\n",
"- $M_{ss} \\in \\mathbb{R}^{n \\times n}$ is the mass matrix.\n",
"- $\\mathbf{h}_s \\in \\mathbb{R}^n$ is the vector of bias forces.\n",
"\n",
"JaxSim computes these quantities for floating-base systems, so we specifically focus on the joint-related portions by marking them with subscripts.\n",
"\n",
"Since no external forces or joint friction are present, we can extend a PD controller with a feed-forward term that includes gravity compensation:\n",
"\n",
"$$\n",
"\\begin{cases}\n",
"\\boldsymbol{\\tau} &= M_{ss} \\, \\ddot{\\mathbf{s}}^* + \\mathbf{h}_s \\\\\n",
"\\ddot{\\mathbf{s}}^* &= \\ddot{\\mathbf{s}}^\\text{des} - k_p(\\mathbf{s} - \\mathbf{s}^{\\text{des}}) - k_d(\\mathbf{s}^{\\text{des}} - \\dot{\\mathbf{s}}^{\\text{des}})\n",
"\\end{cases}\n",
"\\quad\n",
",\n",
"$$\n",
"\n",
"where $\\tilde{\\mathbf{s}} = \\left(\\mathbf{s} - \\mathbf{s}^\\text{des}\\right)$ is the joint position error.\n",
"\n",
"With this control law, the closed-loop system dynamics simplifies to:\n",
"\n",
"$$\n",
"\\ddot{\\tilde{\\mathbf{s}}} = -k_p \\tilde{\\mathbf{s}} - k_d \\dot{\\tilde{\\mathbf{s}}}\n",
",\n",
"$$\n",
"\n",
"which converges asymptotically to zero, ensuring stability."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rfMTCMyGZH4q"
},
"outputs": [],
"source": [
"# @title Create the computed torque controller\n",
"\n",
"# Define the PD gains\n",
"kp = 10.0\n",
"kd = 6.0\n",
"\n",
"\n",
"def computed_torque_controller(\n",
" data: js.data.JaxSimModelData,\n",
" s_des: jax.Array,\n",
" s_dot_des: jax.Array,\n",
") -> jax.Array:\n",
"\n",
" # Compute the gravity compensation term.\n",
" hs = js.model.free_floating_bias_forces(model=model, data=data)[6:]\n",
"\n",
" # Compute the joint-related portion of the floating-base mass matrix.\n",
" Mss = js.model.free_floating_mass_matrix(model=model, data=data)[6:, 6:]\n",
"\n",
" # Get the current joint positions and velocities.\n",
" s = data.joint_positions\n",
" ṡ = data.joint_velocities\n",
"\n",
" # Compute the actuated joint torques.\n",
" s_star = -kp * (s - s_des) - kd * (ṡ - s_dot_des)\n",
" τ = Mss @ s_star + hs\n",
"\n",
" return τ"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ERAUisywZH4q"
},
"source": [
"Now, we can use the `pd_controller` function to compute the torque to apply to the cartpole. Our aim is to stabilize the cartpole in the upright position, so we set the desired position `q_d` to 0 and the desired velocity `q_dot_d` to 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8YmDdGDVZH4q"
},
"outputs": [],
"source": [
"# @title Run the simulation\n",
"\n",
"# Initialize the data.\n",
"\n",
"# Set the joint positions.\n",
"data = js.data.JaxSimModelData.build(model=model, joint_positions=jnp.array([-0.25, jnp.deg2rad(160)]), joint_velocities=jnp.array([3.00, jnp.deg2rad(10) / model.time_step]))\n",
"\n",
"for _ in T:\n",
"\n",
" # Get the actuated torques from the computed torque controller.\n",
" τ = computed_torque_controller(\n",
" data=data,\n",
" s_des=jnp.array([0.0, 0.0]),\n",
" s_dot_des=jnp.array([0.0, 0.0]),\n",
" )\n",
"\n",
" # Step the JaxSim simulation.\n",
" data = js.model.step(\n",
" model=model,\n",
" data=data,\n",
" joint_force_references=τ,\n",
" )\n",
"\n",
" # Update the MuJoCo data.\n",
" mj_model_helper.set_joint_positions(\n",
" positions=data.joint_positions, joint_names=model.joint_names()\n",
" )\n",
"\n",
" # Record a new video frame.\n",
" recorder.record_frame(camera_name=\"cartpole_camera\")\n",
"\n",
"media.show_video(recorder.frames, fps=recorder.fps)\n",
"recorder.frames = []"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sZ76QqeWeMQz"
},
"source": [
"## Conclusions\n",
"\n",
"In this notebook, we explored how to use JaxSim for developing a closed-loop controller for a robot model. Key takeaways include:\n",
"\n",
"- We performed an open-loop simulation to understand the dynamics of the system without control.\n",
"- We implemented a computed torque controller with PD feedback and a feed-forward gravity compensation term, enabling the stabilization of the system by controlling joint torques.\n",
"- The closed-loop simulation can leverage hardware acceleration on GPUs and TPUs, with the ability to use `jax.vmap` for parallel sampling through automatic vectorization.\n",
"\n",
"JaxSim's closed-loop support can be extended to more advanced, model-based reactive controllers and planners for trajectory optimization. To explore optimization-based methods, consider the following JAX-based projects for hardware-accelerated control and planning:\n",
"\n",
"- [`deepmind/optax`](https://github.com/google-deepmind/optax)\n",
"- [`google/jaxopt`](https://github.com/google/jaxopt)\n",
"- [`patrick-kidger/lineax`](https://github.com/patrick-kidger/lineax)\n",
"- [`patrick-kidger/optimistix`](https://github.com/patrick-kidger/optimistix)\n",
"- [`kevin-tracy/qpax`](https://github.com/kevin-tracy/qpax)\n",
"\n",
"Additionally, if your controllers or planners require the derivatives of the dynamics with respect to the state or inputs, you can obtain them using automatic differentiation directly through JaxSim's API."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuClass": "premium",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "comodo_jaxsim",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: pyproject.toml
================================================
[project]
name = "jaxsim"
dynamic = ["version"]
requires-python = ">= 3.10"
description = "A differentiable physics engine and multibody dynamics library for control and robot learning."
authors = [
{ name = "Diego Ferigo", email = "dgferigo@gmail.com" },
{ name = "Filippo Luca Ferretti", email = "filippoluca.ferretti@outlook.com" },
]
maintainers = [
{ name = "Filippo Luca Ferretti", email = "filippo.ferretti@outlook.com" },
]
license = "BSD-3-Clause"
license-files = ["LICENSE"]
keywords = [
"physics",
"physics engine",
"jax",
"rigid body dynamics",
"featherstone",
"reinforcement learning",
"robot",
"robotics",
"sdf",
"urdf",
]
classifiers = [
"Development Status :: 4 - Beta",
"Framework :: Robot Framework",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Operating System :: POSIX :: Linux",
"Operating System :: MacOS",
"Operating System :: Microsoft",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Games/Entertainment :: Simulation",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Physics",
"Topic :: Software Development",
]
dependencies = [
"coloredlogs",
"jax >= 0.4.34",
"jaxlib >= 0.4.34",
"jaxlie >= 1.3.0",
"jax_dataclasses >= 1.4.0",
"pptree",
"optax >= 0.2.3",
"qpax",
"rod >= 0.4.1",
"typing_extensions ; python_version < '3.12'",
"trimesh",
]
[project.optional-dependencies]
testing = [
"chex >= 0.1.91",
"idyntree >= 12.2.1",
"pytest >=6.0",
"pytest-benchmark",
"pytest-icdiff",
"pytest-xdist",
"robot-descriptions >= 1.16.0",
"icub-models",
]
viz = [
"lxml",
"mediapy",
"mujoco >= 3.0.0",
"scipy >= 1.14.0",
]
all = [
"jaxsim[testing,viz]",
]
[project.readme]
file = "README.md"
content-type = "text/markdown"
[project.urls]
Changelog = "https://github.com/gbionics/jaxsim/releases"
Documentation = "https://jaxsim.readthedocs.io"
Source = "https://github.com/gbionics/jaxsim"
Tracker = "https://github.com/gbionics/jaxsim/issues"
# ===========
# Build tools
# ===========
[build-system]
build-backend = "hatchling.build"
requires = [
"hatchling",
"hatch-vcs",
]
[tool.hatch.version]
source = "vcs"
raw-options = { local_scheme = "dirty-tag" }
[tool.hatch.build.targets.wheel]
packages = ["src/jaxsim"]
[tool.hatch.build.hooks.vcs]
version-file = "src/jaxsim/_version.py"
# =================
# Style and testing
# =================
[tool.black]
line-length = 88
[tool.isort]
multi_line_output = 3
profile = "black"
[tool.pytest.ini_options]
addopts = "-rsxX -v --strict-markers --benchmark-skip --benchmark-warmup=ON"
minversion = "6.0"
testpaths = [
"tests",
]
# ==================
# Ruff configuration
# ==================
[tool.ruff]
exclude = [
".git",
".pixi",
".pytest_cache",
".ruff_cache",
".idea",
".vscode",
".devcontainer",
"__pycache__",
]
preview = true
[tool.ruff.lint]
# https://docs.astral.sh/ruff/rules/
select = [
"B",
"D",
"E",
"F",
"I",
"W",
"RUF",
"UP",
"YTT",
]
ignore = [
"B008", # Function call in default argument
"B024", # Abstract base class without abstract methods
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic method
"D200", # One-line docstring should fit on one line with quotes
"D202", # No blank lines allowed after function docstring
"D203", # Incorrect blank line before class
"D205", # 1 blank line required between summary line and description
"D212", # Multi-line docstring summary should start at the first line
"D411", # Missing blank line before section
"D413", # Missing blank line after last section
"E402", # Module level import not at top of file
"E501", # Line too long
"E731", # Do not assign a `lambda` expression, use a `def`
"E741", # Ambiguous variable name
"I001", # Import block is unsorted or unformatted
"RUF003", # Ambiguous unicode character in comment
]
[tool.ruff.lint.per-file-ignores]
# Ignore `E402` (module-level import not at top of file) in the tests, docs, and tools folders.
"**/{tests,docs,tools}/*" = ["E402"]
"**/{tests,examples}/*" = ["B007", "D100", "D102", "D103"]
"__init__.py" = ["F401", "RUF067"]
"docs/conf.py" = ["F401"]
"src/jaxsim/exceptions.py" = ["D401"]
"src/jaxsim/logging.py" = ["D101", "D103"]
# ==================
# Pixi configuration
# ==================
[tool.pixi.workspace]
channels = ["conda-forge"]
platforms = ["linux-64", "linux-aarch64", "osx-arm64", "osx-64"]
requires-pixi = ">=0.39.0"
preview = ["pixi-build"]
[tool.pixi.environments]
# We resolve only two groups: cpu and gpu.
# Then, multiple environments can be created from these groups.
default = { features = ["test", "examples"] }
gpu = { features = ["test", "examples", "gpu"] }
# ---------------
# feature.default
# ---------------
# Dependencies from conda-forge.
[tool.pixi.dependencies]
#
# Matching `project.dependencies`.
#
coloredlogs = "*"
jax = "*"
jaxlib = "*"
jaxlie = "*"
jax-dataclasses = "*"
pptree = "*"
optax = "*"
qpax = "*"
rod = ">=0.4.1"
trimesh = "*"
typing_extensions = "*"
#
# Optional dependencies.
#
lxml = "*"
mediapy = "*"
mujoco = "*"
scipy = "*"
#
# Additional dependencies.
#
pip = "*"
hatchling = "*"
hatch-vcs = "*"
# Dependencies from PyPI.
[tool.pixi.pypi-dependencies]
jaxsim = { path = "./", editable = true }
[tool.pixi.pypi-options]
no-build-isolation = ["jaxsim"]
# ------------
# feature.test
# ------------
[tool.pixi.feature.test.tasks]
pipcheck = "pip check"
benchmark = { cmd = "pytest --benchmark-only --benchmark-warmup=ON", depends-on = ["pipcheck"] }
test = { cmd = "pytest", depends-on = ["pipcheck"] }
[tool.pixi.feature.test.dependencies]
black-jupyter = "*"
chex = ">=0.1.91"
idyntree = "*"
isort = "*"
pre-commit = "*"
pytest = "*"
pytest-benchmark = "*"
pytest-icdiff = "*"
pytest-xdist = "*"
robot_descriptions = ">=1.16.0"
# ----------------
# feature.examples
# ----------------
[tool.pixi.feature.examples.tasks]
examples = { cmd = "jupyter notebook ./examples" }
[tool.pixi.feature.examples.dependencies]
notebook = "*"
robot_descriptions = ">=1.16.0"
# -----------
# feature.gpu
# -----------
[tool.pixi.feature.gpu]
platforms = ["linux-64"]
system-requirements = { cuda = "13" }
[tool.pixi.feature.gpu.dependencies]
jaxlib = { version = "*", build = "*cuda*" }
[tool.pixi.feature.gpu.tasks]
test-gpu = { cmd = "pytest --gpu-only", depends-on = ["pipcheck"] }
================================================
FILE: src/jaxsim/__init__.py
================================================
from . import logging
from ._version import __version__
# Follow upstream development in https://github.com/google/jax/pull/13304
def _jnp_options() -> None:
import os
import jax
# Check if running on TPU.
is_tpu = jax.devices()[0].platform == "tpu"
# Check if running on Metal.
is_metal = jax.devices()[0].platform == "METAL"
# Enable by default 64-bit precision to get accurate physics.
# Users can enforce 32-bit precision by setting the following variable to 0.
use_x64 = os.environ.get("JAX_ENABLE_X64", "1") != "0"
# Notify the user if unsupported 64-bit precision was enforced on TPU.
if (is_tpu or is_metal) and use_x64:
msg = f"64-bit precision is not allowed on {jax.devices()[0].platform.upper}. Enforcing 32bit precision."
logging.warning(msg)
use_x64 = False
if is_metal:
logging.warning(
"JAX Metal backend is experimental. Some functionalities may not be available."
)
# Enable 64-bit precision in JAX.
if use_x64:
logging.info("Enabling JAX to use 64-bit precision")
jax.config.update("jax_enable_x64", True)
# Warn about experimental usage of 32-bit precision.
else:
logging.warning(
"Using 32-bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
)
def _np_options() -> None:
import numpy as np
np.set_printoptions(precision=5, suppress=True, linewidth=150, threshold=10_000)
def _is_editable() -> bool:
import importlib.util
import pathlib
import site
# Get the ModuleSpec of jaxsim.
jaxsim_spec = importlib.util.find_spec(name="jaxsim")
# This can be None. If it's None, assume non-editable installation.
if jaxsim_spec.origin is None:
return False
# Get the folder containing the jaxsim package.
jaxsim_package_dir = str(pathlib.Path(jaxsim_spec.origin).parent.parent)
# The installation is editable if the package dir is not in any {site|dist}-packages.
return jaxsim_package_dir not in site.getsitepackages()
def _get_default_logging_level() -> logging.LoggingLevel:
    """
    Select the logging level applied when the package is imported.

    Returns:
        The level named by the ``JAXSIM_LOGGING_LEVEL`` environment variable if it
        is set to a non-empty value. Otherwise, ``DEBUG`` when ``sys.gettrace()``
        reports an active trace function, ``INFO`` for editable installations,
        and ``WARNING`` in all the other cases.

    Raises:
        RuntimeError: If ``JAXSIM_LOGGING_LEVEL`` does not name a valid level.
    """

    import os
    import sys

    # The environment variable has priority over any automatic choice.
    requested_level = os.environ.get("JAXSIM_LOGGING_LEVEL")

    if requested_level:
        try:
            return logging.LoggingLevel[requested_level.upper()]
        except KeyError as exc:
            msg = "Invalid logging level defined in JAXSIM_LOGGING_LEVEL"
            raise RuntimeError(msg) from exc

    # An active trace function is treated as running under a debugger.
    get_trace = getattr(sys, "gettrace", lambda: None)

    if get_trace():
        return logging.LoggingLevel.DEBUG

    # Keep non-editable installations quiet, and be more verbose for editable ones.
    if _is_editable():  # noqa: F821
        return logging.LoggingLevel.INFO

    return logging.LoggingLevel.WARNING
# Configure the logger with the default logging level.
logging.configure(level=_get_default_logging_level())

# Configure the numerical precision used by JAX.
_jnp_options()

# Initialize the numpy print options.
_np_options()

# Delete the private setup helpers so that they are not left in the package namespace.
del _jnp_options
del _np_options
del _get_default_logging_level
del _is_editable
from . import terrain # isort:skip
from . import api, logging, math, rbda
from .api.common import VelRepr
================================================
FILE: src/jaxsim/api/__init__.py
================================================
from . import common # isort:skip
from . import model, data # isort:skip
from . import (
actuation_model,
com,
contact,
frame,
integrators,
joint,
kin_dyn_parameters,
link,
ode,
references,
)
================================================
FILE: src/jaxsim/api/actuation_model.py
================================================
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.typing as jtp
def compute_resultant_torques(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_force_references: jtp.Vector | None = None,
) -> jtp.Vector:
    """
    Compute the resultant torques acting on the joints.

    The result is the sum of the joint force references, the joint friction torque
    (only if enabled in the actuation parameters), and the torque enforcing the
    joint position limits, clipped to the limits returned by `tn_curve_fn`.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_force_references:
            The joint force references to apply. If None, zero references are used.

    Returns:
        The resultant torques acting on the joints.
    """

    # Build joint torques if not provided.
    τ_references = (
        jnp.atleast_1d(joint_force_references.squeeze())
        if joint_force_references is not None
        else jnp.zeros_like(data.joint_positions)
    ).astype(float)

    # ====================
    # Enforce joint limits
    # ====================

    τ_position_limit = jnp.zeros_like(τ_references).astype(float)

    if model.dofs() > 0:

        # Stiffness and damper parameters for the joint position limits.
        k_j = jnp.array(
            model.kin_dyn_parameters.joint_parameters.position_limit_spring
        ).astype(float)
        d_j = jnp.array(
            model.kin_dyn_parameters.joint_parameters.position_limit_damper
        ).astype(float)

        # Compute the joint position limit violations.
        # The lower violation is non-positive, the upper violation is non-negative.
        lower_violation = jnp.clip(
            data.joint_positions
            - model.kin_dyn_parameters.joint_parameters.position_limits_min,
            max=0.0,
        )

        upper_violation = jnp.clip(
            data.joint_positions
            - model.kin_dyn_parameters.joint_parameters.position_limits_max,
            min=0.0,
        )

        # Compute the joint position limit torque.
        # The first term is the spring contribution.
        τ_position_limit -= jnp.diag(k_j) @ (lower_violation + upper_violation)

        # The second term is the damper contribution. Since `*` and `@` have the same
        # precedence, it is evaluated as `(positive(τ) * diag(d_j)) @ velocities`.
        # NOTE(review): `jnp.positive` is the element-wise unary plus (it does not
        # build a mask of the positive entries), so the damper term is scaled by the
        # spring torque computed above -- confirm this is the intended behavior.
        τ_position_limit -= (
            jnp.positive(τ_position_limit) * jnp.diag(d_j) @ data.joint_velocities
        )

    # ====================
    # Joint friction model
    # ====================

    τ_friction = jnp.zeros_like(τ_references).astype(float)

    # Apply joint friction only if enabled in the actuation parameters.
    if model.dofs() > 0 and model.actuation_params.enable_friction:

        # Static and viscous joint friction parameters
        kc = jnp.array(
            model.kin_dyn_parameters.joint_parameters.friction_static
        ).astype(float)
        kv = jnp.array(
            model.kin_dyn_parameters.joint_parameters.friction_viscous
        ).astype(float)

        # Compute the joint friction torque.
        τ_friction = -(
            jnp.diag(kc) @ jnp.sign(data.joint_velocities)
            + jnp.diag(kv) @ data.joint_velocities
        )

    # ===============================
    # Compute the total joint forces.
    # ===============================

    τ_total = τ_references + τ_friction + τ_position_limit

    # Saturate the total torque symmetrically with the velocity-dependent limit.
    τ_lim = tn_curve_fn(model=model, data=data)
    τ_total = jnp.clip(τ_total, -τ_lim, τ_lim)

    return τ_total
def tn_curve_fn(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    """
    Evaluate the torque-speed (tn) curve to get the joint torque limits.

    The limit is the maximum torque up to the threshold speed, it decreases
    linearly between the threshold speed and the maximum speed, and it is zero
    beyond the maximum speed.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The torque limits.
    """

    actuation = model.actuation_params

    torque_peak = actuation.torque_max  # Max torque (Nm)
    speed_knee = actuation.omega_th  # Threshold speed (rad/s)
    speed_cutoff = actuation.omega_max  # Max speed for torque drop-off (rad/s)

    speed = jnp.abs(data.joint_velocities)

    # Torque along the linear drop-off segment of the curve.
    derated_torque = torque_peak * (
        1 - (speed - speed_knee) / (speed_cutoff - speed_knee)
    )

    # Limit for speeds above the threshold: drop-off segment, then zero.
    limit_above_knee = jnp.where(speed <= speed_cutoff, derated_torque, 0.0)

    return jnp.where(speed <= speed_knee, torque_peak, limit_above_knee)
================================================
FILE: src/jaxsim/api/com.py
================================================
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.math
import jaxsim.typing as jtp
from .common import VelRepr
@jax.jit
@js.common.named_scope
def com_position(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    """
    Compute the position of the center of mass of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The position of the center of mass of the model w.r.t. the world frame.
    """

    total_mass = js.model.total_mass(model=model)

    W_H_L = data._link_transforms
    W_H_B = data._base_transform
    B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B)

    def mass_weighted_link_com(link_index) -> jtp.Vector:
        # Homogeneous CoM position of the link, scaled by the link mass and
        # expressed in the base frame.
        link_mass = js.link.mass(model=model, link_index=link_index)

        L_p_LCoM = js.link.com_position(
            model=model, data=data, link_index=link_index, in_link_frame=True
        )

        return link_mass * B_H_W @ W_H_L[link_index] @ jnp.hstack([L_p_LCoM, 1])

    link_indices = jnp.arange(model.number_of_links())
    weighted_coms = jax.vmap(mass_weighted_link_com)(link_indices)

    # Average over the total mass and restore the homogeneous coordinate.
    B_p̃_CoM = ((1 / total_mass) * weighted_coms.sum(axis=0)).at[3].set(1)

    W_p̃_CoM = W_H_B @ B_p̃_CoM

    return W_p̃_CoM[0:3].astype(float)
@jax.jit
@js.common.named_scope
def com_linear_velocity(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the linear velocity of the center of mass of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The linear velocity of the center of mass of the model in the
        active representation.

    Note:
        The linear velocity of the center of mass is expressed in the mixed frame
        :math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
        active velocity representation is either inertial-fixed or mixed,
        and :math:`[C] = [B]` if the active velocity representation is body-fixed.
    """

    # The 6D average centroidal velocity is expressed in G[B] in body-fixed
    # representation, and in G[W] in inertial-fixed or mixed representation.
    G_v_WG = average_centroidal_velocity(model=model, data=data)

    # Keep only its linear component.
    return G_v_WG[0:3]
@jax.jit
@js.common.named_scope
def centroidal_momentum(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the centroidal momentum of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The centroidal momentum of the model.

    Note:
        The centroidal momentum is expressed in the mixed frame
        :math:`({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`C = W` if the
        active velocity representation is either inertial-fixed or mixed,
        and :math:`C = B` if the active velocity representation is body-fixed.
    """

    generalized_velocity = data.generalized_velocity

    # Map the generalized velocity through the centroidal momentum Jacobian.
    momentum_jacobian = centroidal_momentum_jacobian(model=model, data=data)

    return momentum_jacobian @ generalized_velocity
@jax.jit
@js.common.named_scope
def centroidal_momentum_jacobian(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    r"""
    Compute the Jacobian of the centroidal momentum of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The Jacobian of the centroidal momentum of the model.

    Note:
        The frame corresponding to the output representation of this Jacobian is either
        :math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,
        or :math:`G[B]`, if the active velocity representation is body-fixed.

    Note:
        This Jacobian is also known in the literature as Centroidal Momentum Matrix.
    """

    # Start from the Jacobian of the total momentum with body-fixed output
    # representation. Its output is converted to either G[W] or G[B] below.
    B_Jh = js.model.total_momentum_jacobian(
        model=model, data=data, output_vel_repr=VelRepr.Body
    )

    W_H_B = data._base_transform
    B_H_W = jaxsim.math.Transform.inverse(W_H_B)

    W_p_CoM = com_position(model=model, data=data)

    # Build the transform of the frame G, located at the CoM.
    velocity_representation = data.velocity_representation

    if velocity_representation in (VelRepr.Inertial, VelRepr.Mixed):
        # G[W]: orientation of the world frame.
        W_H_G = jnp.eye(4).at[0:3, 3].set(W_p_CoM)

    elif velocity_representation == VelRepr.Body:
        # G[B]: orientation of the base frame.
        W_H_G = W_H_B.at[0:3, 3].set(W_p_CoM)

    else:
        raise ValueError(velocity_representation)

    # Compute the transform for 6D forces.
    B_Xv_G = jaxsim.math.Adjoint.from_transform(transform=B_H_W @ W_H_G)
    G_Xf_B = B_Xv_G.T

    return G_Xf_B @ B_Jh
@jax.jit
@js.common.named_scope
def locked_centroidal_spatial_inertia(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
):
    """
    Compute the locked centroidal spatial inertia of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The locked centroidal spatial inertia of the model.
    """

    # Get the locked spatial inertia in body-fixed representation.
    with data.switch_velocity_representation(VelRepr.Body):
        B_Mbb_B = js.model.locked_spatial_inertia(model=model, data=data)

    W_H_B = data._base_transform
    W_p_CoM = com_position(model=model, data=data)

    # Build the transform of the frame G, located at the CoM.
    velocity_representation = data.velocity_representation

    if velocity_representation in (VelRepr.Inertial, VelRepr.Mixed):
        # G[W]: orientation of the world frame.
        W_H_G = jnp.eye(4).at[0:3, 3].set(W_p_CoM)

    elif velocity_representation == VelRepr.Body:
        # G[B]: orientation of the base frame.
        W_H_G = W_H_B.at[0:3, 3].set(W_p_CoM)

    else:
        raise ValueError(velocity_representation)

    B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G

    # Express the inertia in G with the velocity transform and its transpose.
    B_Xv_G = jaxsim.math.Adjoint.from_transform(transform=B_H_G)

    return B_Xv_G.transpose() @ B_Mbb_B @ B_Xv_G
@jax.jit
@js.common.named_scope
def average_centroidal_velocity(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the average centroidal velocity of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The average centroidal velocity of the model.

    Note:
        The average velocity is expressed in the mixed frame
        :math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
        active velocity representation is either inertial-fixed or mixed,
        and :math:`[C] = [B]` if the active velocity representation is body-fixed.
    """

    generalized_velocity = data.generalized_velocity

    # Map the generalized velocity through the average centroidal velocity Jacobian.
    velocity_jacobian = average_centroidal_velocity_jacobian(model=model, data=data)

    return velocity_jacobian @ generalized_velocity
@jax.jit
@js.common.named_scope
def average_centroidal_velocity_jacobian(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    r"""
    Compute the Jacobian of the average centroidal velocity of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The Jacobian of the average centroidal velocity of the model.

    Note:
        The frame corresponding to the output representation of this Jacobian is either
        :math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,
        or :math:`G[B]`, if the active velocity representation is body-fixed.
    """

    G_J = centroidal_momentum_jacobian(model=model, data=data)
    G_Mbb = locked_centroidal_spatial_inertia(model=model, data=data)

    # Solve the linear system G_Mbb @ X = G_J instead of forming the explicit
    # inverse of the locked inertia and multiplying by it.
    return jnp.linalg.solve(G_Mbb, G_J)
@jax.jit
@js.common.named_scope
def bias_acceleration(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
r"""
Compute the bias linear acceleration of the center of mass.
Args:
model: The model to consider.
data: The data of the considered model.
Returns:
The bias linear acceleration of the center of mass in the active representation.
Note:
The bias acceleration is expressed in the mixed frame
:math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
active velocity representation is either inertial-fixed or mixed,
and :math:`[C] = [B]` if the active velocity representation is body-fixed.
"""
# Compute the pose of all links with forward kinematics.
W_H_L = data._link_transforms
# Compute the bias acceleration of all links by zeroing the generalized velocity
# in the active representation.
v̇_bias_WL = js.model.link_bias_accelerations(model=model, data=data)
def other_representation_to_body(
C_v̇_WL: jtp.Vector, C_v_WC: jtp.Vector, L_H_C: jtp.Matrix, L_v_LC: jtp.Vector
) -> jtp.Vector:
"""
Convert the body-fixed representation of the link bias acceleration
C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL.
"""
L_X_C = jaxsim.math.Adjoint.from_transform(transform=L_H_C)
C_X_L = jaxsim.math.Adjoint.inverse(L_X_C)
L_v̇_WL = L_X_C @ (C_v̇_WL + jaxsim.math.Cross.vx(C_X_L @ L_v_LC) @ C_v_WC)
return L_v̇_WL
# We need here to get the body-fixed bias acceleration of the links.
# Since it's computed in the active representation, we need to convert it to body.
match data.velocity_representation:
case VelRepr.Body:
L_a_bias_WL = v̇_bias_WL
case VelRepr.Inertial:
C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841
C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
L_H_C = L_H_W = jax.vmap(jaxsim.math.Transform.inverse)(W_H_L) # noqa: F841
L_v_LC = L_v_LW = jax.vmap( # noqa: F841
lambda i: -js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
)
)(jnp.arange(model.number_of_links()))
L_a_bias_WL = jax.vmap(
lambda i: other_representation_to_body(
C_v̇_WL=C_v̇_WL[i],
C_v_WC=C_v_WC,
L_H_C=L_H_C[i],
L_v_LC=L_v_LC[i],
)
)(jnp.arange(model.number_of_links()))
case VelRepr.Mixed:
C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL # noqa: F841
C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841
lambda i: js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed
)
.at[3:6]
.set(jnp.zeros(3))
)(jnp.arange(model.number_of_links()))
L_H_C = L_H_LW = jax.vmap( # noqa: F841
lambda W_H_L: jaxsim.math.Transform.inverse(
W_H_L.at[0:3, 3].set(jnp.zeros(3))
)
)(W_H_L)
L_v_LC = L_v_L_LW = jax.vmap( # noqa: F841
lambda i: -js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
)
.at[0:3]
.set(jnp.zeros(3))
)(jnp.arange(model.number_of_links()))
L_a_bias_WL = jax.vmap(
lambda i: other_representation_to_body(
C_v̇_WL=C_v̇_WL[i],
C_v_WC=C_v_WC[i],
L_H_C=L_H_C[i],
L_v_LC=L_v_LC[i],
)
)(jnp.arange(model.number_of_links()))
case _:
raise ValueError(data.velocity_representation)
# Compute the bias of the 6D momentum derivative.
def bias_momentum_derivative_term(
link_index: jtp.Int, L_a_bias_WL: jtp.Vector
) -> jtp.Vector:
# Get the body-fixed 6D inertia matrix.
L_M_L = js.link.spatial_inertia(model=model, link_index=link_index)
# Compute the body-fixed 6D velocity.
L_v_WL = js.link.velocity(
model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body
)
# Compute the world-to-link transformations for 6D forces.
W_Xf_L = jaxsim.math.Adjoint.from_transform(
transform=W_H_L[link_index], inverse=True
).T
# Compute the contribution of the link to the bias acceleration of the CoM.
W_ḣ_bias_link_contribution = W_Xf_L @ (
L_M_L @ L_a_bias_WL + jaxsim.math.Cross.vx_star(L_v_WL) @ L_M_L @ L_v_WL
)
return W_ḣ_bias_link_contribution
# Sum the contributions of all links to the bias acceleration of the CoM.
W_ḣ_bias = jax.vmap(bias_momentum_derivative_term)(
jnp.arange(model.number_of_links()), L_a_bias_WL
).sum(axis=0)
# Compute the total mass of the model.
m = js.model.total_mass(model=model)
# Compute the position of the CoM.
W_p_CoM = com_position(model=model, data=data)
match data.velocity_representation:
# G := G[W] = (W_p_CoM, [W])
case VelRepr.Inertial | VelRepr.Mixed:
W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
GW_Xf_W = jaxsim.math.Adjoint.from_transform(W_H_GW).T
GW_ḣ_bias = GW_Xf_W @ W_ḣ_bias
GW_v̇l_com_bias = GW_ḣ_bias[0:3] / m
return GW_v̇l_com_bias
# G := G[B] = (W_p_CoM, [B])
case VelRepr.Body:
GB_Xf_W = jaxsim.math.Adjoint.from_transform(
transform=data._base_transform.at[0:3].set(W_p_CoM)
).T
GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias
GB_v̇l_com_bias = GB_ḣ_bias[0:3] / m
return GB_v̇l_com_bias
case _:
raise ValueError(data.velocity_representation)
================================================
FILE: src/jaxsim/api/common.py
================================================
import abc
import contextlib
import dataclasses
import enum
import functools
from collections.abc import Callable, Iterator
from typing import ParamSpec, TypeVar
import jax
import jax.numpy as jnp
import jax_dataclasses
from jax_dataclasses import Static
import jaxsim.typing as jtp
from jaxsim.math import Adjoint
from jaxsim.utils import JaxsimDataclass, Mutability
try:
from typing import Self
except ImportError:
from typing_extensions import Self
_P = ParamSpec("_P")
_R = TypeVar("_R")
def named_scope(fn, name: str | None = None) -> Callable[_P, _R]:
    """
    Wrap a function so that it runs inside a JAX named scope.

    Args:
        fn: The function to wrap.
        name: The name of the scope. If not given, the name of the function is used.

    Returns:
        The wrapped function.
    """

    @functools.wraps(fn)
    def scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:

        # The scope name is resolved at call time.
        scope = jax.named_scope(name or fn.__name__)

        with scope:
            return fn(*args, **kwargs)

    return scoped
@enum.unique
class VelRepr(enum.IntEnum):
    """
    Supported representations of 6D velocities.
    """

    # The explicit values match those generated by `enum.auto()`.
    Body = 1
    Mixed = 2
    Inertial = 3
@jax_dataclasses.pytree_dataclass
class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
    """
    Base class for model data structures with velocity representation.

    It stores the active velocity representation, provides a context manager to
    switch it temporarily, and static helpers to convert 6D quantities between
    the inertial-fixed representation and the other ones.
    """

    # Active velocity representation. It is a static, keyword-only field that
    # defaults to the inertial-fixed representation.
    velocity_representation: Static[VelRepr] = dataclasses.field(
        default=VelRepr.Inertial, kw_only=True
    )

    @contextlib.contextmanager
    def switch_velocity_representation(
        self, velocity_representation: VelRepr
    ) -> Iterator[Self]:
        """
        Context manager to temporarily switch the velocity representation.

        The original velocity representation is restored when the context exits,
        also if an exception was raised inside it.

        Args:
            velocity_representation: The new velocity representation.

        Yields:
            The same object with the new velocity representation.
        """

        # Store the representation to restore when leaving the context.
        original_representation = self.velocity_representation

        try:

            # First, we replace the velocity representation.
            with self.mutable_context(
                mutability=Mutability.MUTABLE_NO_VALIDATION,
                restore_after_exception=True,
            ):
                self.velocity_representation = velocity_representation

            # Then, we yield the data with changed representation.
            # We run this in a mutable context with restoration so that, if any
            # exception occurs, the original object is restored in case it was modified.
            with self.mutable_context(
                mutability=self.mutability(), restore_after_exception=True
            ):
                yield self

        finally:
            # Always switch back to the original velocity representation.
            with self.mutable_context(
                mutability=Mutability.MUTABLE_NO_VALIDATION,
                restore_after_exception=True,
            ):
                self.velocity_representation = original_representation

    @staticmethod
    @functools.partial(jax.jit, static_argnames=["other_representation", "is_force"])
    def inertial_to_other_representation(
        array: jtp.Array,
        other_representation: VelRepr,
        transform: jtp.Matrix,
        *,
        is_force: bool,
    ) -> jtp.Array:
        r"""
        Convert a 6D quantity from inertial-fixed to another representation.

        Args:
            array: The 6D quantity to convert.
            other_representation: The representation to convert to.
            transform:
                The :math:`W \mathbf{H}_O` transform, where :math:`O` is the
                reference frame of the other representation.
            is_force: Whether the quantity is a 6D force or a 6D velocity.

        Returns:
            The 6D quantity in the other representation.

        Raises:
            ValueError: If the other representation is not supported.
        """

        W_array = array
        W_H_O = transform

        match other_representation:

            # Nothing to convert.
            case VelRepr.Inertial:
                return W_array

            case VelRepr.Body:

                if not is_force:
                    O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True)
                    O_array = jnp.einsum("...ij,...j->...i", O_Xv_W, W_array)

                else:
                    # Forces use the transpose of the velocity transform of the
                    # opposite direction.
                    O_Xf_W = Adjoint.from_transform(transform=W_H_O).swapaxes(-1, -2)
                    O_array = jnp.einsum("...ij,...j->...i", O_Xf_W, W_array)

                return O_array

            case VelRepr.Mixed:
                # Replace the rotation of the transform with the identity, so that
                # only its translation is used.
                W_H_OW = W_H_O.at[..., 0:3, 0:3].set(jnp.eye(3))

                if not is_force:
                    OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True)
                    OW_array = jnp.einsum("...ij,...j->...i", OW_Xv_W, W_array)

                else:
                    OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).swapaxes(-1, -2)
                    OW_array = jnp.einsum("...ij,...j->...i", OW_Xf_W, W_array)

                return OW_array

            case _:
                raise ValueError(other_representation)

    @staticmethod
    @functools.partial(jax.jit, static_argnames=["other_representation", "is_force"])
    def other_representation_to_inertial(
        array: jtp.Array,
        other_representation: VelRepr,
        transform: jtp.Matrix,
        *,
        is_force: bool,
    ) -> jtp.Array:
        r"""
        Convert a 6D quantity from another representation to inertial-fixed.

        Args:
            array: The 6D quantity to convert.
            other_representation: The representation to convert from.
            transform:
                The :math:`W \mathbf{H}_O` transform, where :math:`O` is the
                reference frame of the other representation.
            is_force: Whether the quantity is a 6D force or a 6D velocity.

        Returns:
            The 6D quantity in the inertial-fixed representation.

        Raises:
            ValueError: If the other representation is not supported.
        """

        O_array = array
        W_H_O = transform

        match other_representation:

            # Nothing to convert.
            case VelRepr.Inertial:
                return O_array

            case VelRepr.Body:

                if not is_force:
                    W_Xv_O = Adjoint.from_transform(W_H_O)
                    W_array = jnp.einsum("...ij,...j->...i", W_Xv_O, O_array)

                else:
                    # Forces use the transpose of the velocity transform of the
                    # opposite direction.
                    W_Xf_O = Adjoint.from_transform(
                        transform=W_H_O, inverse=True
                    ).swapaxes(-1, -2)
                    W_array = jnp.einsum("...ij,...j->...i", W_Xf_O, O_array)

                return W_array

            case VelRepr.Mixed:
                # Replace the rotation of the transform with the identity, so that
                # only its translation is used.
                W_H_OW = W_H_O.at[..., 0:3, 0:3].set(jnp.eye(3))

                if not is_force:
                    W_Xv_BW = Adjoint.from_transform(W_H_OW)
                    W_array = jnp.einsum("...ij,...j->...i", W_Xv_BW, O_array)

                else:
                    W_Xf_BW = Adjoint.from_transform(
                        transform=W_H_OW, inverse=True
                    ).swapaxes(-1, -2)
                    W_array = jnp.einsum("...ij,...j->...i", W_Xf_BW, O_array)

                return W_array

            case _:
                raise ValueError(other_representation)
================================================
FILE: src/jaxsim/api/contact.py
================================================
from __future__ import annotations
import functools
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.exceptions
import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.math import Adjoint, Cross, Transform
from jaxsim.rbda.contacts import SoftContacts
from .common import VelRepr
@jax.jit
@js.common.named_scope
def collidable_point_kinematics(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> tuple[jtp.Matrix, jtp.Matrix]:
    """
    Compute the position and 3D velocity of the collidable points in the world frame.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The position and velocity of the collidable points in the world frame.

    Note:
        The collidable point velocity is the plain coordinate derivative of the position.
        If we attach a frame C = (p_C, [C]) to the collidable point, it corresponds to
        the linear component of the mixed 6D frame velocity.
    """

    # Use the link transforms and velocities stored in the data.
    kinematics = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
        model=model,
        link_transforms=data._link_transforms,
        link_velocities=data._link_velocities,
    )

    positions, velocities = kinematics

    return positions, velocities
@jax.jit
@js.common.named_scope
def collidable_point_positions(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    """
    Return the world-frame position of the collidable points.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The position of the collidable points in the world frame.
    """

    # The kinematics helper returns a (positions, velocities) pair;
    # only the first element is of interest here.
    positions, _unused_velocities = collidable_point_kinematics(
        model=model, data=data
    )

    return positions
@jax.jit
@js.common.named_scope
def collidable_point_velocities(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    """
    Return the world-frame 3D velocity of the collidable points.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The 3D velocity of the collidable points.
    """

    # The kinematics helper returns a (positions, velocities) pair;
    # only the second element is of interest here.
    _unused_positions, velocities = collidable_point_kinematics(
        model=model, data=data
    )

    return velocities
@functools.partial(jax.jit, static_argnames=["link_names"])
@js.common.named_scope
def in_contact(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_names: tuple[str, ...] | None = None,
) -> jtp.Vector:
    """
    Tell which links are touching the terrain.

    A link is reported in contact when at least one of its enabled collidable
    points lies at or below the terrain height evaluated at the point's (x, y).

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_names:
            The names of the links to consider. If None, all links are considered.

    Returns:
        A boolean vector indicating whether the links are in contact with the terrain.

    Raises:
        ValueError: If any of the given link names is not part of the model.
    """

    # Reject names that do not belong to the model.
    if link_names is not None:
        unknown_names = set(link_names).difference(model.link_names())

        if unknown_names:
            raise ValueError("One or more link names are not part of the model")

    # Parent link index of each enabled collidable point.
    contact_parameters = model.kin_dyn_parameters.contact_parameters
    enabled = contact_parameters.indices_of_enabled_collidable_points
    parent_link_of_point = jnp.array(contact_parameters.body, dtype=int)[enabled]

    # World-frame positions of the collidable points.
    W_p_C = collidable_point_positions(model=model, data=data)

    def height_at(x, y):
        return model.terrain.height(x=x, y=y)

    # Terrain height underneath every collidable point.
    terrain_height = jax.vmap(height_at)(W_p_C[:, 0], W_p_C[:, 1])

    # A point is considered in contact when it is at or below the terrain.
    point_is_below = W_p_C[:, 2] <= terrain_height

    # Indices of the links to report.
    if link_names is None:
        link_idxs = jnp.arange(model.number_of_links())
    else:
        link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)

    def link_touches_terrain(link_index):
        # Keep only the points rigidly attached to the considered link.
        belongs_to_link = parent_link_of_point == link_index
        return jnp.logical_and(belongs_to_link, point_is_below).any()

    return jax.vmap(link_touches_terrain)(link_idxs)
def estimate_good_soft_contacts_parameters(
    *args, **kwargs
) -> jaxsim.rbda.contacts.ContactParamsTypes:
    """
    Deprecated alias of `estimate_good_contact_parameters`.

    A deprecation warning is logged, then all positional and keyword arguments
    are forwarded unchanged to `estimate_good_contact_parameters`.
    """

    replacement = estimate_good_contact_parameters

    # Warn about the deprecation before delegating to the new entry point.
    logging.warning(
        "This method is deprecated, please use `{}`.".format(replacement.__name__)
    )

    return replacement(*args, **kwargs)
def estimate_good_contact_parameters(
    model: js.model.JaxSimModel,
    *,
    standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
    static_friction_coefficient: jtp.FloatLike = 0.5,
    number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
    damping_ratio: jtp.FloatLike = 1.0,
    max_penetration: jtp.FloatLike | None = None,
) -> jaxsim.rbda.contacts.ContactParamsTypes:
    """
    Estimate a reasonable set of contact parameters for the given model.

    Args:
        model: The model to consider.
        standard_gravity: The standard gravity acceleration.
        static_friction_coefficient: The static friction coefficient.
        number_of_active_collidable_points_steady_state:
            The number of active collidable points in steady state.
        damping_ratio: The damping ratio.
        max_penetration:
            The maximum penetration allowed. If None, it defaults to 1% of the
            height of the center of mass computed in the zero configuration.

    Returns:
        The estimated good contacts parameters.

    Note:
        This is primarily a convenience function for soft-like contact models.
        However, it provides with some good default parameters also for the other ones.

    Note:
        This method provides a good set of contacts parameters.
        The user is encouraged to fine-tune the parameters based on the
        specific application.
    """

    if max_penetration is None:

        # Evaluate the model in its default (zero) state.
        zero_data = js.data.JaxSimModelData.build(model=model)

        # Height of the center of mass in the world frame.
        com_height = js.com.com_position(model=model, data=zero_data)[2]

        if model.floating_base():
            # For floating-base models, measure the height of the center of mass
            # from the lowest collidable point.
            point_heights = collidable_point_positions(model=model, data=zero_data)[
                :, -1
            ]
            com_height = com_height - point_heights.min()

        # Consider as default a 1% of the model center of mass height.
        max_penetration = 0.01 * com_height

    # Let the parameters class of the active contact model build its defaults.
    parameters_builder = model.contact_model._parameters_class()

    return parameters_builder.build_default_from_jaxsim_model(
        model=model,
        standard_gravity=standard_gravity,
        static_friction_coefficient=static_friction_coefficient,
        max_penetration=max_penetration,
        number_of_active_collidable_points_steady_state=(
            number_of_active_collidable_points_steady_state
        ),
        damping_ratio=damping_ratio,
    )
@jax.jit
@js.common.named_scope
def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
    r"""
    Return the pose of the enabled collidable points.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The stacked SE(3) matrices of all enabled collidable points.

    Note:
        Each collidable point is implicitly associated with a frame
        :math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the
        collidable point and :math:`[L]` is the orientation frame of the link it is
        rigidly attached to.
    """

    contact_parameters = model.kin_dyn_parameters.contact_parameters

    # Select the enabled collidable points and look up their parent links.
    enabled = contact_parameters.indices_of_enabled_collidable_points
    parent_link_idxs = jnp.array(contact_parameters.body, dtype=int)[enabled]

    # World-to-link transform of the parent link of each enabled point.
    W_H_L = data._link_transforms[parent_link_idxs]

    # Position of each enabled point expressed in its parent link frame.
    L_p_C = contact_parameters.point[enabled]

    def point_frame_in_world(W_H_Li: jtp.Matrix, L_p_Ci: jtp.Vector) -> jtp.Matrix:
        # The implicit frame C shares the orientation of the link frame L,
        # therefore the link-to-point transform is a pure translation.
        L_H_Ci = jnp.eye(4).at[0:3, 3].set(L_p_Ci)

        # Compose the world-to-link and link-to-point transforms.
        return W_H_Li @ L_H_Ci

    return jax.vmap(point_frame_in_world)(W_H_L, L_p_C)
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def jacobian(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Array:
    r"""
    Return the free-floating Jacobian of the enabled collidable points.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian.
            If None, the velocity representation of `data` is used.

    Returns:
        The stacked :math:`6 \times (6+n)` free-floating jacobians of the frames associated to the
        enabled collidable points.

    Raises:
        ValueError: If the output velocity representation is not supported.

    Note:
        Each collidable point is implicitly associated with a frame
        :math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the
        collidable point and :math:`[L]` is the orientation frame of the link it is
        rigidly attached to.
    """

    # Fall back to the representation stored in data.
    if output_vel_repr is None:
        output_vel_repr = data.velocity_representation

    # Parent link index of each enabled collidable point.
    contact_parameters = model.kin_dyn_parameters.contact_parameters
    enabled = contact_parameters.indices_of_enabled_collidable_points
    parent_link_idxs = jnp.array(contact_parameters.body, dtype=int)[enabled]

    # Jacobians of all links with inertial-fixed output representation.
    W_J_WL = js.model.generalized_free_floating_jacobian(
        model=model, data=data, output_vel_repr=VelRepr.Inertial
    )

    # In inertial-fixed output representation, the Jacobian of the parent link is
    # also the Jacobian of the frame C implicitly associated with the point.
    W_J_WC = W_J_WL[parent_link_idxs]

    # Nothing to adjust if the inertial-fixed representation is requested.
    if output_vel_repr == VelRepr.Inertial:
        return W_J_WC

    if output_vel_repr == VelRepr.Body:

        def change_output_repr(W_H_Ci: jtp.Matrix, W_J_WCi: jtp.Matrix) -> jtp.Matrix:
            # Express the output velocity in the frame C of the point.
            C_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_Ci, inverse=True)
            return C_X_W @ W_J_WCi

    elif output_vel_repr == VelRepr.Mixed:

        def change_output_repr(W_H_Ci: jtp.Matrix, W_J_WCi: jtp.Matrix) -> jtp.Matrix:
            # Express the output velocity in the frame C[W], having the origin of C
            # and the orientation of the world frame.
            W_H_CW = W_H_Ci.at[0:3, 0:3].set(jnp.eye(3))
            CW_X_W = jaxsim.math.Adjoint.from_transform(
                transform=W_H_CW, inverse=True
            )
            return CW_X_W @ W_J_WCi

    else:
        raise ValueError(output_vel_repr)

    # Pose of the frames associated with the enabled collidable points.
    W_H_C = transforms(model=model, data=data)

    return jax.vmap(change_output_repr)(W_H_C, W_J_WC)
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def jacobian_derivative(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    r"""
    Compute the derivative of the free-floating jacobian of the enabled collidable points.

    The result is assembled with the product rule from three factors: the
    operator changing the output representation, the inertial-fixed Jacobian of
    the parent link, and the operator changing the input representation.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian derivative.
            If None, the velocity representation of `data` is used.

    Returns:
        The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the enabled collidable points.
        The per-point matrices are stacked along the leading axis.

    Raises:
        ValueError: If either the active or the output velocity representation
            is not supported.

    Note:
        The input representation of the free-floating jacobian derivative is the active
        velocity representation.
    """

    # Fall back to the representation stored in data when none is requested.
    output_vel_repr = (
        output_vel_repr if output_vel_repr is not None else data.velocity_representation
    )

    indices_of_enabled_collidable_points = (
        model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
    )

    # Get the index of the parent link and the position of the collidable point.
    parent_link_idx_of_enabled_collidable_points = jnp.array(
        model.kin_dyn_parameters.contact_parameters.body, dtype=int
    )[indices_of_enabled_collidable_points]

    L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
        indices_of_enabled_collidable_points
    ]

    # Get the transforms of all the parent links.
    W_H_Li = data._link_transforms

    # Get the link velocities.
    W_v_WLi = data._link_velocities

    # =====================================================
    # Compute quantities to adjust the input representation
    # =====================================================

    def compute_T(model: js.model.JaxSimModel, X: jtp.Matrix) -> jtp.Matrix:
        # Block-diagonal operator: X acts on the 6D base velocity,
        # the identity acts on the joint velocities.
        In = jnp.eye(model.dofs())
        T = jax.scipy.linalg.block_diag(X, In)
        return T

    def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:
        # Time derivative of T: the identity block is constant, hence the
        # joint block is zero.
        On = jnp.zeros(shape=(model.dofs(), model.dofs()))
        Ṫ = jax.scipy.linalg.block_diag(Ẋ, On)
        return Ṫ

    # Compute the operator to change the representation of ν, and its
    # time derivative.
    match data.velocity_representation:
        case VelRepr.Inertial:
            # The generalized velocity is already inertial-fixed: identity operator.
            W_H_W = jnp.eye(4)
            W_X_W = Adjoint.from_transform(transform=W_H_W)
            W_Ẋ_W = jnp.zeros((6, 6))

            T = compute_T(model=model, X=W_X_W)
            Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W)

        case VelRepr.Body:
            W_H_B = data._base_transform
            W_X_B = Adjoint.from_transform(transform=W_H_B)
            B_v_WB = data.base_velocity
            B_vx_WB = Cross.vx(B_v_WB)
            W_Ẋ_B = W_X_B @ B_vx_WB

            T = compute_T(model=model, X=W_X_B)
            Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B)

        case VelRepr.Mixed:
            W_H_B = data._base_transform
            # Frame B[W]: origin of the base frame, orientation of the world frame.
            W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
            W_X_BW = Adjoint.from_transform(transform=W_H_BW)
            BW_v_WB = data.base_velocity
            # Zero the last three components of the base velocity
            # before building the cross-product operator.
            BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
            BW_vx_W_BW = Cross.vx(BW_v_W_BW)
            W_Ẋ_BW = W_X_BW @ BW_vx_W_BW

            T = compute_T(model=model, X=W_X_BW)
            Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW)

        case _:
            raise ValueError(data.velocity_representation)

    # =====================================================
    # Compute quantities to adjust the output representation
    # =====================================================

    with data.switch_velocity_representation(VelRepr.Inertial):
        # Compute the Jacobian of the parent link in inertial representation.
        W_J_WL_W = js.model.generalized_free_floating_jacobian(
            model=model,
            data=data,
        )

        # Compute the Jacobian derivative of the parent link in inertial representation.
        W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative(
            model=model,
            data=data,
        )

    def compute_O_J̇_WC_I(
        L_p_C: jtp.Vector,
        parent_link_idx: jtp.Int,
        W_H_L: jtp.Matrix,
    ) -> jtp.Matrix:
        # Jacobian derivative of a single collidable point.
        # Note that W_H_L stores the transforms of all links (it is not mapped
        # over by vmap), and it is indexed here with the parent link index.

        match output_vel_repr:
            case VelRepr.Inertial:
                # Identity operator, with zero time derivative.
                O_X_W = W_X_W = Adjoint.from_transform(  # noqa: F841
                    transform=jnp.eye(4)
                )
                O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6))  # noqa: F841

            case VelRepr.Body:
                # Pose of the frame C implicitly associated with the point.
                L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
                W_H_C = W_H_L[parent_link_idx] @ L_H_C
                O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
                # The velocity of the parent link is used as the velocity of C.
                W_v_WC = W_v_WLi[parent_link_idx]
                W_vx_WC = Cross.vx(W_v_WC)
                O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC  # noqa: F841

            case VelRepr.Mixed:
                L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
                W_H_C = W_H_L[parent_link_idx] @ L_H_C
                # Frame C[W]: origin of C, orientation of the world frame.
                W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
                CW_H_W = Transform.inverse(W_H_CW)
                O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W)
                CW_v_WC = CW_X_W @ W_v_WLi[parent_link_idx]
                # Keep only the first three components of the velocity
                # before building the cross-product operator.
                W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3])
                W_vx_W_CW = Cross.vx(W_v_W_CW)
                O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW  # noqa: F841

            case _:
                raise ValueError(output_vel_repr)

        # Product rule applied to O_X_W @ W_J_WL_W @ T.
        O_J̇_WC_I = jnp.zeros(shape=(6, 6 + model.dofs()))
        O_J̇_WC_I += O_Ẋ_W @ W_J_WL_W[parent_link_idx] @ T
        O_J̇_WC_I += O_X_W @ W_J̇_WL_W[parent_link_idx] @ T
        O_J̇_WC_I += O_X_W @ W_J_WL_W[parent_link_idx] @ Ṫ

        return O_J̇_WC_I

    # Map over the enabled collidable points, sharing the link transforms.
    O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, None))(
        L_p_Ci, parent_link_idx_of_enabled_collidable_points, W_H_Li
    )

    return O_J̇_WC
@jax.jit
@js.common.named_scope
def link_contact_forces(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_forces: jtp.MatrixLike | None = None,
    joint_torques: jtp.VectorLike | None = None,
) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]:
    """
    Compute the 6D contact forces of all links of the model in inertial representation.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_forces:
            The 6D external forces to apply to the links expressed in inertial representation
        joint_torques:
            The joint torques acting on the joints.

    Returns:
        A tuple containing the `(nL, 6)` array with the stacked 6D contact forces
        of the links, expressed in inertial representation, and the dictionary of
        auxiliary data returned by the active contact model.
    """

    # The external forces and the joint torques are forwarded to every contact
    # model except the soft one, that is called without them.
    extra_kwargs = {}

    if not isinstance(model.contact_model, SoftContacts):
        extra_kwargs = dict(
            link_forces=link_forces, joint_force_references=joint_torques
        )

    # Compute the contact forces of each collidable point with the active model.
    contact_forces, aux_dict = model.contact_model.compute_contact_forces(
        model=model, data=data, **extra_kwargs
    )

    # Compute the 6D forces applied to the links equivalent to the forces applied
    # to the frames associated to the collidable points.
    forces_on_links = link_forces_from_contact_forces(
        model=model, contact_forces=contact_forces
    )

    return forces_on_links, aux_dict
def link_forces_from_contact_forces(
    model: js.model.JaxSimModel,
    *,
    contact_forces: jtp.MatrixLike,
) -> jtp.Matrix:
    """
    Accumulate the forces of the collidable points on their parent links.

    Args:
        model: The robot model considered by the contact model.
        contact_forces:
            The contact forces computed by the contact model, one row for each
            enabled collidable point.

    Returns:
        The 6D contact forces applied to the links, one row for each link.
        No coordinate transformation is applied: the forces of the points are
        only summed, therefore the output is expressed in the same (world)
        frame of the input.
    """

    params = model.kin_dyn_parameters.contact_parameters

    # Only the enabled collidable points contribute to the link forces.
    enabled = params.indices_of_enabled_collidable_points

    # Make sure to operate on a 2D float array, one row for each point.
    forces_C = jnp.array(contact_forces, dtype=float).squeeze()
    forces_C = jnp.atleast_2d(forces_C)

    # Parent link index of each enabled collidable point.
    parent_of_point = jnp.array(params.body, dtype=int)[enabled]

    # Boolean selector with one row for each point and one column for each link,
    # that is True when the point is rigidly attached to the link.
    link_range = jnp.arange(model.number_of_links())
    point_on_link = parent_of_point[:, jnp.newaxis] == link_range

    # Sum the forces of all the points attached to the same link.
    # The forces are expressed in the world frame, hence they can be summed
    # without any coordinate transformation.
    return point_on_link.T @ forces_C
================================================
FILE: src/jaxsim/api/data.py
================================================
from __future__ import annotations
import dataclasses
import functools
from collections.abc import Sequence
try:
from typing import Self, override
except ImportError:
from typing_extensions import override, Self
import jax
import jax.numpy as jnp
import jax.scipy.spatial.transform
import jax_dataclasses
import jaxsim.api as js
import jaxsim.math
import jaxsim.rbda
import jaxsim.typing as jtp
from . import common
from .common import VelRepr
@jax_dataclasses.pytree_dataclass
class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
    """
    Class storing the state of the physics model dynamics.

    Attributes:
        joint_positions: The vector of joint positions.
        joint_velocities: The vector of joint velocities.
        base_position: The 3D position of the base link.
        base_quaternion: The quaternion defining the orientation of the base link.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        base_transform: The base transform.
        joint_transforms: The joint transforms.
        link_transforms: The link transforms.
        link_velocities: The link velocities in inertial-fixed representation.
        contact_state: The extended state used by the contact models.
    """

    # Joint state
    _joint_positions: jtp.Vector
    _joint_velocities: jtp.Vector

    # Base state
    _base_quaternion: jtp.Vector
    _base_linear_velocity: jtp.Vector
    _base_angular_velocity: jtp.Vector
    _base_position: jtp.Vector

    # Cached computations.
    _base_transform: jtp.Matrix = dataclasses.field(repr=False, default=None)
    _joint_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
    _link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
    _link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)

    # Extended state for soft and rigid contact models.
    contact_state: dict[str, jtp.Array] = dataclasses.field(default_factory=dict)

    @staticmethod
    def build(
        model: js.model.JaxSimModel,
        base_position: jtp.VectorLike | None = None,
        base_quaternion: jtp.VectorLike | None = None,
        joint_positions: jtp.VectorLike | None = None,
        base_linear_velocity: jtp.VectorLike | None = None,
        base_angular_velocity: jtp.VectorLike | None = None,
        joint_velocities: jtp.VectorLike | None = None,
        contact_state: dict[str, jtp.Array] | None = None,
        velocity_representation: VelRepr = VelRepr.Mixed,
    ) -> JaxSimModelData:
        """
        Create a `JaxSimModelData` object with the given state.

        Args:
            model: The model for which to create the state.
            base_position: The base position. It defaults to zero.
            base_quaternion:
                The base orientation as a quaternion. It defaults to the
                identity quaternion `[1, 0, 0, 0]`.
            joint_positions: The joint positions. They default to zero.
            base_linear_velocity:
                The base linear velocity in the selected representation.
                It defaults to zero.
            base_angular_velocity:
                The base angular velocity in the selected representation.
                It defaults to zero.
            joint_velocities: The joint velocities. They default to zero.
            contact_state:
                The optional contact state. The passed dictionary is copied,
                it is never modified in place.
            velocity_representation:
                The velocity representation to use. It defaults to mixed if
                not provided.

        Returns:
            A `JaxSimModelData` initialized with the given state.

        Raises:
            ValueError: If the built state is not compatible with the model.
        """

        base_position = jnp.array(
            base_position if base_position is not None else jnp.zeros(3),
            dtype=float,
        ).squeeze()

        base_quaternion = jnp.array(
            (
                base_quaternion
                if base_quaternion is not None
                else jnp.array([1.0, 0, 0, 0])
            ),
            dtype=float,
        ).squeeze()

        base_linear_velocity = jnp.array(
            base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3),
            dtype=float,
        ).squeeze()

        base_angular_velocity = jnp.array(
            (
                base_angular_velocity
                if base_angular_velocity is not None
                else jnp.zeros(3)
            ),
            dtype=float,
        ).squeeze()

        joint_positions = jnp.atleast_1d(
            jnp.array(
                (
                    joint_positions
                    if joint_positions is not None
                    else jnp.zeros(model.dofs())
                ),
                dtype=float,
            ).squeeze()
        )

        joint_velocities = jnp.atleast_1d(
            jnp.array(
                (
                    joint_velocities
                    if joint_velocities is not None
                    else jnp.zeros(model.dofs())
                ),
                dtype=float,
            ).squeeze()
        )

        W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
            translation=base_position, quaternion=base_quaternion
        )

        # The base velocity is always stored in inertial-fixed representation.
        W_v_WB = JaxSimModelData.other_representation_to_inertial(
            array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
            other_representation=velocity_representation,
            transform=W_H_B,
            is_force=False,
        ).astype(float)

        joint_transforms = model.kin_dyn_parameters.joint_transforms(
            joint_positions=joint_positions, base_transform=W_H_B
        )

        link_transforms, link_velocities_inertial = (
            jaxsim.rbda.forward_kinematics_model(
                model=model,
                base_position=base_position,
                base_quaternion=base_quaternion,
                joint_positions=joint_positions,
                base_linear_velocity_inertial=W_v_WB[0:3],
                base_angular_velocity_inertial=W_v_WB[3:6],
                joint_velocities=joint_velocities,
                joint_transforms=joint_transforms,
            )
        )

        # Work on a shallow copy so that the entries added below never leak
        # into the dictionary owned by the caller.
        contact_state = dict(contact_state) if contact_state else {}

        if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
            # The soft contact model needs the tangential deformation in its
            # state: initialize it to zero when it is not provided.
            contact_state["tangential_deformation"] = contact_state.get(
                "tangential_deformation",
                jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point),
            )

        model_data = JaxSimModelData(
            velocity_representation=velocity_representation,
            _base_quaternion=base_quaternion,
            _base_position=base_position,
            _joint_positions=joint_positions,
            _base_linear_velocity=W_v_WB[0:3],
            _base_angular_velocity=W_v_WB[3:6],
            _joint_velocities=joint_velocities,
            _base_transform=W_H_B,
            _joint_transforms=joint_transforms,
            _link_transforms=link_transforms,
            _link_velocities=link_velocities_inertial,
            contact_state=contact_state,
        )

        if not model_data.valid(model=model):
            raise ValueError(
                "The built state is not compatible with the model.", model_data
            )

        return model_data

    @staticmethod
    def zero(
        model: js.model.JaxSimModel,
        velocity_representation: VelRepr = VelRepr.Mixed,
    ) -> JaxSimModelData:
        """
        Create a `JaxSimModelData` object with zero state.

        Args:
            model: The model for which to create the state.
            velocity_representation: The velocity representation to use. It defaults to mixed if not provided.

        Returns:
            A `JaxSimModelData` initialized with zero state.
        """

        return JaxSimModelData.build(
            model=model, velocity_representation=velocity_representation
        )

    # ==================
    # Extract quantities
    # ==================

    @property
    def joint_positions(self) -> jtp.Vector:
        """
        Get the joint positions.

        Returns:
            The joint positions.
        """

        return self._joint_positions

    @property
    def joint_velocities(self) -> jtp.Vector:
        """
        Get the joint velocities.

        Returns:
            The joint velocities.
        """

        return self._joint_velocities

    @property
    def base_quaternion(self) -> jtp.Vector:
        """
        Get the base quaternion.

        Returns:
            The base quaternion, as stored in the state.
        """

        return self._base_quaternion

    @property
    def base_position(self) -> jtp.Vector:
        """
        Get the base position.

        Returns:
            The base position.
        """

        return self._base_position

    @property
    def base_orientation(self) -> jtp.Vector:
        """
        Get the base orientation.

        Returns:
            The base orientation as a normalized quaternion.
        """

        # Extract the base quaternion.
        W_Q_B = self.base_quaternion

        # Always normalize the quaternion to avoid numerical issues.
        # If the active scheme does not integrate the quaternion on its manifold,
        # we introduce a Baumgarte stabilization to let the quaternion converge to
        # a unit quaternion. In this case, it is not guaranteed that the quaternion
        # stored in the state is a unit quaternion.
        norm = jaxsim.math.safe_norm(W_Q_B, axis=-1, keepdims=True)

        # The eps term is active only when the norm is zero, and prevents
        # a division by zero.
        W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))

        return W_Q_B

    @property
    def base_velocity(self) -> jtp.Vector:
        """
        Get the base 6D velocity.

        Returns:
            The base 6D velocity in the active representation.
        """

        # The velocity is stored in inertial-fixed representation.
        W_v_WB = jnp.concatenate(
            [self._base_linear_velocity, self._base_angular_velocity], axis=-1
        )

        W_H_B = self._base_transform

        return (
            JaxSimModelData.inertial_to_other_representation(
                array=W_v_WB,
                other_representation=self.velocity_representation,
                transform=W_H_B,
                is_force=False,
            )
            .squeeze()
            .astype(float)
        )

    @property
    def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
        r"""
        Get the generalized position
        :math:`\mathbf{q} = ({}^W \mathbf{H}_B, \mathbf{s}) \in \text{SE}(3) \times \mathbb{R}^n`.

        Returns:
            A tuple containing the base transform and the joint positions.
        """

        return self._base_transform, self.joint_positions

    @property
    def generalized_velocity(self) -> jtp.Vector:
        r"""
        Get the generalized velocity.

        :math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \dot{\mathbf{s}}) \in \mathbb{R}^{6+n}`

        Returns:
            The generalized velocity in the active representation.
        """

        return (
            jnp.hstack([self.base_velocity, self.joint_velocities])
            .squeeze()
            .astype(float)
        )

    @property
    def base_transform(self) -> jtp.Matrix:
        """
        Get the base transform.

        Returns:
            The base transform.
        """

        return self._base_transform

    # ================
    # Store quantities
    # ================

    @js.common.named_scope
    @jax.jit
    def reset_base_quaternion(
        self, model: js.model.JaxSimModel, base_quaternion: jtp.VectorLike
    ) -> Self:
        """
        Reset the base quaternion.

        Args:
            model: The JaxSim model to use.
            base_quaternion: The base orientation as a quaternion.

        Returns:
            The updated `JaxSimModelData` object.
        """

        W_Q_B = jnp.array(base_quaternion, dtype=float)

        # Normalize the quaternion, guarding against a zero norm.
        # The norm keeps its last dimension so that the division broadcasts
        # in the same way as in `base_orientation`.
        norm = jaxsim.math.safe_norm(W_Q_B, axis=-1, keepdims=True)
        W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))

        return self.replace(model=model, base_quaternion=W_Q_B)

    @js.common.named_scope
    @jax.jit
    def reset_base_pose(
        self, model: js.model.JaxSimModel, base_pose: jtp.MatrixLike
    ) -> Self:
        """
        Reset the base pose.

        Args:
            model: The JaxSim model to use.
            base_pose: The base pose as an SE(3) matrix.

        Returns:
            The updated `JaxSimModelData` object.
        """

        base_pose = jnp.array(base_pose)

        # Split the homogeneous matrix in its translation and rotation parts.
        W_p_B = base_pose[0:3, 3]
        W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3])

        return self.replace(
            model=model,
            base_position=W_p_B,
            base_quaternion=W_Q_B,
        )

    @override
    def replace(
        self,
        model: js.model.JaxSimModel,
        joint_positions: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
        base_position: jtp.Vector | None = None,
        *,
        contact_state: dict[str, jtp.Array] | None = None,
        validate: bool = False,
    ) -> Self:
        """
        Replace the attributes of the `JaxSimModelData` object.

        The quantities that are not provided keep their current value. The cached
        base, joint, and link transforms and the link velocities are always
        recomputed from the resulting state.

        Args:
            model: The model used to recompute the cached kinematics.
            joint_positions: The new joint positions.
            joint_velocities: The new joint velocities.
            base_quaternion:
                The new base quaternion. It gets normalized before being stored.
            base_linear_velocity:
                The new base linear velocity, in the active representation.
            base_angular_velocity:
                The new base angular velocity, in the active representation.
            base_position: The new base position.
            contact_state: The new contact state.
            validate: Forwarded to the `replace` method of the parent class.

        Returns:
            The updated `JaxSimModelData` object.
        """

        if joint_positions is None:
            joint_positions = self.joint_positions

        if joint_velocities is None:
            joint_velocities = self.joint_velocities

        if base_quaternion is None:
            base_quaternion = self.base_quaternion

        if base_position is None:
            base_position = self.base_position

        if contact_state is None:
            contact_state = self.contact_state

        # Normalize the quaternion to avoid numerical issues.
        base_quaternion_norm = jaxsim.math.safe_norm(
            base_quaternion, axis=-1, keepdims=True
        )

        base_quaternion = base_quaternion / jnp.where(
            base_quaternion_norm == 0, 1.0, base_quaternion_norm
        )

        # Add a leading axis to all the quantities so that the kinematics can be
        # computed with vmap. The original shapes are restored before returning.
        joint_positions = jnp.atleast_2d(joint_positions.squeeze()).astype(float)
        joint_velocities = jnp.atleast_2d(joint_velocities.squeeze()).astype(float)
        base_quaternion = jnp.atleast_2d(base_quaternion.squeeze()).astype(float)
        base_position = jnp.atleast_2d(base_position.squeeze()).astype(float)

        base_transform = jaxsim.math.Transform.from_quaternion_and_translation(
            translation=base_position, quaternion=base_quaternion
        ).reshape((-1, 4, 4))

        joint_transforms = jax.vmap(model.kin_dyn_parameters.joint_transforms)(
            joint_positions=joint_positions,
            base_transform=base_transform,
        )

        if base_linear_velocity is None and base_angular_velocity is None:
            # No new base velocity: keep the stored inertial-fixed velocity.
            base_linear_velocity_inertial = self._base_linear_velocity
            base_angular_velocity_inertial = self._base_angular_velocity
        else:
            # Fill the missing half of the base velocity with the current value
            # expressed in the active representation.
            if base_linear_velocity is None:
                base_linear_velocity = self.base_velocity[:3]
            if base_angular_velocity is None:
                base_angular_velocity = self.base_velocity[3:]

            base_linear_velocity = jnp.atleast_1d(base_linear_velocity.squeeze())
            base_angular_velocity = jnp.atleast_1d(base_angular_velocity.squeeze())

            # Convert the base velocity to inertial-fixed representation
            # using the updated base transform.
            W_v_WB = JaxSimModelData.other_representation_to_inertial(
                array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
                other_representation=self.velocity_representation,
                transform=base_transform,
                is_force=False,
            ).astype(float)

            base_linear_velocity_inertial, base_angular_velocity_inertial = (
                W_v_WB[..., :3],
                W_v_WB[..., 3:],
            )

        link_transforms, link_velocities = jax.vmap(
            jaxsim.rbda.forward_kinematics_model, in_axes=(None,)
        )(
            model,
            base_position=base_position,
            base_quaternion=base_quaternion,
            joint_positions=joint_positions,
            joint_velocities=joint_velocities,
            base_linear_velocity_inertial=jnp.atleast_2d(base_linear_velocity_inertial),
            base_angular_velocity_inertial=jnp.atleast_2d(
                base_angular_velocity_inertial
            ),
            joint_transforms=joint_transforms,
        )

        # Adjust the output shapes.
        joint_positions = joint_positions.reshape(self._joint_positions.shape)
        joint_velocities = joint_velocities.reshape(self._joint_velocities.shape)
        base_quaternion = base_quaternion.reshape(self._base_quaternion.shape)
        base_linear_velocity_inertial = base_linear_velocity_inertial.reshape(
            self._base_linear_velocity.shape
        )
        base_angular_velocity_inertial = base_angular_velocity_inertial.reshape(
            self._base_angular_velocity.shape
        )
        base_position = base_position.reshape(self._base_position.shape)
        base_transform = base_transform.reshape(self._base_transform.shape)
        joint_transforms = joint_transforms.reshape(self._joint_transforms.shape)
        link_transforms = link_transforms.reshape(self._link_transforms.shape)
        link_velocities = link_velocities.reshape(self._link_velocities.shape)

        return super().replace(
            _joint_positions=joint_positions,
            _joint_velocities=joint_velocities,
            _base_quaternion=base_quaternion,
            _base_linear_velocity=base_linear_velocity_inertial,
            _base_angular_velocity=base_angular_velocity_inertial,
            _base_position=base_position,
            _base_transform=base_transform,
            _joint_transforms=joint_transforms,
            _link_transforms=link_transforms,
            _link_velocities=link_velocities,
            contact_state=contact_state,
            validate=validate,
        )

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check if the `JaxSimModelData` is valid for a given `JaxSimModel`.

        Only the shapes of the joint and base quantities are checked.

        Args:
            model: The `JaxSimModel` to validate the `JaxSimModelData` against.

        Returns:
            `True` if the `JaxSimModelData` is valid for the given model,
            `False` otherwise.
        """

        if self._joint_positions.shape != (model.dofs(),):
            return False

        if self._joint_velocities.shape != (model.dofs(),):
            return False

        if self._base_position.shape != (3,):
            return False

        if self._base_quaternion.shape != (4,):
            return False

        if self._base_linear_velocity.shape != (3,):
            return False

        if self._base_angular_velocity.shape != (3,):
            return False

        return True
@functools.partial(jax.jit, static_argnames=["velocity_representation", "base_rpy_seq"])
def random_model_data(
    model: js.model.JaxSimModel,
    *,
    key: jax.Array | None = None,
    velocity_representation: VelRepr | None = None,
    base_pos_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = ((-1, -1, 0.5), 1.0),
    base_rpy_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-jnp.pi, jnp.pi),
    base_rpy_seq: str = "XYZ",
    joint_pos_bounds: (
        tuple[
            jtp.FloatLike | Sequence[jtp.FloatLike],
            jtp.FloatLike | Sequence[jtp.FloatLike],
        ]
        | None
    ) = None,
    base_vel_lin_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
    base_vel_ang_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
    joint_vel_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
) -> JaxSimModelData:
    """
    Randomly generate a `JaxSimModelData` object.

    All quantities are sampled from uniform distributions within the given
    bounds. Each bound is a `(min, max)` pair whose entries can be either a
    scalar or a sequence with one value per component.

    Args:
        model: The target model for the random data.
        key: The random key (initialized from seed 0 if None).
        velocity_representation:
            The velocity representation to use (the default of
            `JaxSimModelData.build` if None).
        base_pos_bounds: The bounds for the base position.
        base_rpy_bounds:
            The bounds for the euler angles used to build the base orientation.
        base_rpy_seq:
            The sequence of axes for rotation (using `Rotation` from scipy).
        joint_pos_bounds:
            The bounds for the joint positions (reading the joint limits if None).
        base_vel_lin_bounds: The bounds for the base linear velocity.
        base_vel_ang_bounds: The bounds for the base angular velocity.
        joint_vel_bounds: The bounds for the joint velocities.

    Returns:
        A `JaxSimModelData` object with random data.

    Note:
        Joint quantities are sampled only if the model has joints, and base
        velocities only if the model is floating-base. The quantities that are
        not sampled are passed as None to `JaxSimModelData.build`.
    """

    key = key if key is not None else jax.random.PRNGKey(seed=0)

    # Use an independent sub-key for each sampled quantity.
    k1, k2, k3, k4, k5, k6 = jax.random.split(key, num=6)

    p_min = jnp.array(base_pos_bounds[0], dtype=float)
    p_max = jnp.array(base_pos_bounds[1], dtype=float)
    rpy_min = jnp.array(base_rpy_bounds[0], dtype=float)
    rpy_max = jnp.array(base_rpy_bounds[1], dtype=float)
    v_min = jnp.array(base_vel_lin_bounds[0], dtype=float)
    v_max = jnp.array(base_vel_lin_bounds[1], dtype=float)
    ω_min = jnp.array(base_vel_ang_bounds[0], dtype=float)
    ω_max = jnp.array(base_vel_ang_bounds[1], dtype=float)
    ṡ_min, ṡ_max = joint_vel_bounds

    base_position = jax.random.uniform(key=k1, shape=(3,), minval=p_min, maxval=p_max)

    # Sample the euler angles, build the rotation, and convert the resulting
    # xyzw quaternion to the wxyz convention.
    base_quaternion = jaxsim.math.Quaternion.to_wxyz(
        xyzw=jax.scipy.spatial.transform.Rotation.from_euler(
            seq=base_rpy_seq,
            angles=jax.random.uniform(
                key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max
            ),
        ).as_quat()
    )

    (
        joint_positions,
        joint_velocities,
        base_linear_velocity,
        base_angular_velocity,
    ) = (None,) * 4

    if model.number_of_joints() > 0:

        # Convert the lower and upper bounds separately, as done for all the
        # other bounds. Packing both of them into a single array would fail
        # when a scalar bound is paired with a per-joint sequence.
        s_min, s_max = (
            (
                jnp.array(joint_pos_bounds[0], dtype=float),
                jnp.array(joint_pos_bounds[1], dtype=float),
            )
            if joint_pos_bounds is not None
            else (None, None)
        )

        # Fall back to the joint limits of the model if no bounds were given.
        joint_positions = (
            js.joint.random_joint_positions(model=model, key=k3)
            if (s_min is None or s_max is None)
            else jax.random.uniform(
                key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
            )
        )

        joint_velocities = jax.random.uniform(
            key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
        )

    if model.floating_base():

        base_linear_velocity = jax.random.uniform(
            key=k5, shape=(3,), minval=v_min, maxval=v_max
        )

        base_angular_velocity = jax.random.uniform(
            key=k6, shape=(3,), minval=ω_min, maxval=ω_max
        )

    return JaxSimModelData.build(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        joint_velocities=joint_velocities,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        # Forward the representation only if specified, so that the default
        # of `JaxSimModelData.build` applies otherwise.
        **(
            {"velocity_representation": velocity_representation}
            if velocity_representation is not None
            else {}
        ),
    )
================================================
FILE: src/jaxsim/api/frame.py
================================================
import functools
from collections.abc import Sequence
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import exceptions
from jaxsim.math import Adjoint, Cross
from .common import VelRepr
# =======================
# Index-related functions
# =======================
@jax.jit
@js.common.named_scope
def idx_of_parent_link(
    model: js.model.JaxSimModel, *, frame_index: jtp.IntLike
) -> jtp.Int:
    """
    Get the index of the link to which the frame is rigidly attached.

    Args:
        model: The model to consider.
        frame_index: The index of the frame.

    Returns:
        The index of the frame's parent link.
    """

    num_links = model.number_of_links()
    num_frames = len(model.frame_names())

    # Frame indices are accepted only in [num_links, num_links + num_frames).
    too_small = frame_index < num_links
    too_large = frame_index >= num_links + num_frames

    exceptions.raise_value_error_if(
        condition=jnp.array([too_small, too_large]).any(),
        msg="Invalid frame index '{idx}'",
        idx=frame_index,
    )

    # The per-frame parameters are stored without the link offset.
    parent_links = jnp.array(model.kin_dyn_parameters.frame_parameters.body)

    return parent_links[frame_index - num_links]
@functools.partial(jax.jit, static_argnames="frame_name")
@js.common.named_scope
def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int:
    """
    Convert the name of a frame to its index.

    Args:
        model: The model to consider.
        frame_name: The name of the frame.

    Returns:
        The index of the frame.

    Raises:
        ValueError: If the model has no frame with the given name.
    """

    known_frames = model.frame_names()

    if frame_name not in known_frames:
        raise ValueError(f"Frame '{frame_name}' not found in the model.")

    # Position of the frame among the frames of the model.
    local_index = model.kin_dyn_parameters.frame_parameters.name.index(frame_name)

    # Frame indices start right after the link indices.
    frame_index = jnp.array(model.number_of_links() + local_index)

    return frame_index.astype(int).squeeze()
def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str:
    """
    Convert the index of a frame to its name.

    Args:
        model: The model to consider.
        frame_index: The index of the frame.

    Returns:
        The name of the frame.
    """

    num_links = model.number_of_links()
    num_frames = len(model.frame_names())

    # Frame indices are accepted only in [num_links, num_links + num_frames).
    out_of_range = jnp.array(
        [frame_index < num_links, frame_index >= num_links + num_frames]
    ).any()

    exceptions.raise_value_error_if(
        condition=out_of_range,
        msg="Invalid frame index '{idx}'",
        idx=frame_index,
    )

    # Remove the link offset to address the per-frame names.
    names = model.kin_dyn_parameters.frame_parameters.name

    return names[frame_index - num_links]
@functools.partial(jax.jit, static_argnames=["frame_names"])
@js.common.named_scope
def names_to_idxs(
    model: js.model.JaxSimModel, *, frame_names: Sequence[str]
) -> jax.Array:
    """
    Convert a sequence of frame names to their corresponding indices.

    Args:
        model: The model to consider.
        frame_names: The names of the frames.

    Returns:
        The indices of the frames.
    """

    # Resolve the names one at a time, preserving the order of the input.
    indices = []

    for frame_name in frame_names:
        indices.append(name_to_idx(model=model, frame_name=frame_name))

    return jnp.array(indices).astype(int)
def idxs_to_names(
    model: js.model.JaxSimModel, *, frame_indices: Sequence[jtp.IntLike]
) -> tuple[str, ...]:
    """
    Convert a sequence of frame indices to their corresponding names.

    Args:
        model: The model to consider.
        frame_indices: The indices of the frames.

    Returns:
        The names of the frames.
    """

    # Resolve the indices one at a time, preserving the order of the input.
    names = []

    for frame_index in frame_indices:
        names.append(idx_to_name(model=model, frame_index=frame_index))

    return tuple(names)
# ==========
# Frame APIs
# ==========
@jax.jit
@js.common.named_scope
def transform(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    frame_index: jtp.IntLike,
) -> jtp.Matrix:
    """
    Compute the SE(3) transform from the world frame to the specified frame.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        frame_index: The index of the frame for which the transform is requested.

    Returns:
        The 4x4 matrix representing the transform.
    """

    num_links = model.number_of_links()
    num_frames = len(model.frame_names())

    # Frame indices are accepted only in [num_links, num_links + num_frames).
    out_of_range = jnp.array(
        [frame_index < num_links, frame_index >= num_links + num_frames]
    ).any()

    exceptions.raise_value_error_if(
        condition=out_of_range,
        msg="Invalid frame index '{idx}'",
        idx=frame_index,
    )

    # Pose of the parent link wrt the world.
    parent_link = idx_of_parent_link(model=model, frame_index=frame_index)
    W_H_L = js.link.transform(model=model, data=data, link_index=parent_link)

    # Static pose of the frame wrt its parent link.
    frame_poses = model.kin_dyn_parameters.frame_parameters.transform
    L_H_F = frame_poses[frame_index - num_links]

    # Chain the two transforms to get the pose of the frame wrt the world.
    W_H_F = W_H_L @ L_H_F

    return W_H_F
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def velocity(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    frame_index: jtp.IntLike,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Vector:
    """
    Compute the 6D velocity of the frame.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        frame_index: The index of the frame.
        output_vel_repr:
            The output velocity representation of the frame velocity
            (the representation active in `data` if None).

    Returns:
        The 6D velocity of the frame in the specified velocity representation.
    """

    num_links = model.number_of_links()
    num_frames = model.number_of_frames()

    # Frame indices are accepted only in [num_links, num_links + num_frames).
    out_of_range = jnp.array(
        [frame_index < num_links, frame_index >= num_links + num_frames]
    ).any()

    exceptions.raise_value_error_if(
        condition=out_of_range,
        msg="Invalid frame index '{idx}'",
        idx=frame_index,
    )

    if output_vel_repr is None:
        output_vel_repr = data.velocity_representation

    # The frame jacobian takes the generalized velocity expressed in the
    # representation active in `data` and produces the frame velocity in the
    # requested output representation.
    return (
        jacobian(
            model=model,
            data=data,
            frame_index=frame_index,
            output_vel_repr=output_vel_repr,
        )
        @ data.generalized_velocity
    )
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def jacobian(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    frame_index: jtp.IntLike,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    r"""
    Compute the free-floating jacobian of the frame.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        frame_index: The index of the frame.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian
            (the representation active in `data` if None).

    Returns:
        The :math:`6 \times (6+n)` free-floating jacobian of the frame.

    Note:
        The input representation of the free-floating jacobian is the active
        velocity representation.
    """

    num_links = model.number_of_links()
    num_frames = model.number_of_frames()

    # Frame indices are accepted only in [num_links, num_links + num_frames).
    out_of_range = jnp.array(
        [frame_index < num_links, frame_index >= num_links + num_frames]
    ).any()

    exceptions.raise_value_error_if(
        condition=out_of_range,
        msg="Invalid frame index '{idx}'",
        idx=frame_index,
    )

    if output_vel_repr is None:
        output_vel_repr = data.velocity_representation

    # The frame jacobian is obtained from the body-fixed jacobian of the
    # parent link, changing only the frame in which the output is expressed.
    parent_link = idx_of_parent_link(model=model, frame_index=frame_index)

    L_J_WL = js.link.jacobian(
        model=model,
        data=data,
        link_index=parent_link,
        output_vel_repr=VelRepr.Body,
    )

    # Pose of the parent link wrt the world, read from the cached transforms.
    W_H_L = data._link_transforms[parent_link]
    W_R_L = W_H_L[0:3, 0:3]
    W_p_L = W_H_L[0:3, 3]

    # Static pose of the frame wrt its parent link.
    L_H_F = model.kin_dyn_parameters.frame_parameters.transform[
        frame_index - num_links
    ]
    L_R_F = L_H_F[0:3, 0:3]
    L_p_F = L_H_F[0:3, 3]

    # Select the 6D transform O_X_L that maps the link-fixed output to the
    # requested output representation.
    if output_vel_repr == VelRepr.Inertial:
        # O = W: express the output in the world frame.
        O_X_L = Adjoint.from_rotation_and_translation(
            rotation=W_R_L,
            translation=W_p_L,
        )

    elif output_vel_repr == VelRepr.Body:
        # O = F: invert the static link-to-frame transform.
        O_X_L = Adjoint.from_rotation_and_translation(
            rotation=L_R_F,
            translation=L_p_F,
            inverse=True,
        )

    elif output_vel_repr == VelRepr.Mixed:
        # O = F[W]: origin of the frame with the orientation of the world.
        O_X_L = Adjoint.from_rotation_and_translation(
            rotation=W_R_L,
            translation=-W_R_L @ L_p_F,
        )

    else:
        raise ValueError(output_vel_repr)

    return O_X_L @ L_J_WL
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def jacobian_derivative(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    frame_index: jtp.IntLike,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    r"""
    Compute the derivative of the free-floating jacobian of the frame.

    The result is assembled with the product rule from the inertial-fixed
    jacobian of the parent link and its derivative:
    :math:`\dot{J}_O = \dot{X}_{O,W} J_W + X_{O,W} \dot{J}_W`,
    where :math:`X_{O,W}` maps the inertial-fixed output to the requested
    output representation.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        frame_index: The index of the frame.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian
            derivative (the representation active in `data` if None).

    Returns:
        The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the frame.

    Raises:
        ValueError: If the output velocity representation is not supported.

    Note:
        The input representation of the free-floating jacobian derivative is the active
        velocity representation.
    """

    n_l = model.number_of_links()
    n_f = len(model.frame_names())

    # Frame indices are accepted only in [n_l, n_l + n_f).
    exceptions.raise_value_error_if(
        condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),
        msg="Invalid frame index '{idx}'",
        idx=frame_index,
    )

    output_vel_repr = (
        output_vel_repr if output_vel_repr is not None else data.velocity_representation
    )

    # Get the index of the parent link.
    L = idx_of_parent_link(model=model, frame_index=frame_index)

    # Jacobian of the parent link with inertial-fixed output representation.
    W_J_WL_I = js.link.jacobian(
        model=model,
        data=data,
        link_index=L,
        output_vel_repr=VelRepr.Inertial,
    )

    # Derivative of the parent-link jacobian, again with inertial-fixed output.
    W_J̇_WL_I = js.link.jacobian_derivative(
        model=model,
        data=data,
        link_index=L,
        output_vel_repr=VelRepr.Inertial,
    )

    # Pose of the frame wrt the world, chaining the cached pose of the parent
    # link with the static pose of the frame wrt the link.
    W_H_L = data._link_transforms[L]
    L_H_F = model.kin_dyn_parameters.frame_parameters.transform[
        frame_index - model.number_of_links()
    ]
    W_H_F = W_H_L @ L_H_F

    # =====================================================
    # Compute quantities to adjust the output representation
    # =====================================================

    # Inertial-fixed 6D velocity computed from the parent-link jacobian.
    # NOTE(review): it is used as the velocity of the frame, which relies on
    # the frame being rigidly attached to the link.
    W_v_WF = W_J_WL_I @ data.generalized_velocity

    match output_vel_repr:
        case VelRepr.Inertial:
            # The output is already inertial-fixed: identity transform with
            # zero derivative.
            O_X_W = jnp.eye(6, dtype=W_H_F.dtype)
            O_Ẋ_W = jnp.zeros((6, 6), dtype=W_H_F.dtype)

        case VelRepr.Body:
            # O = F: inverse of the adjoint built from the pose of the frame.
            O_X_W = Adjoint.from_rotation_and_translation(
                rotation=W_H_F[0:3, 0:3],
                translation=W_H_F[0:3, 3],
                inverse=True,
            )
            O_Ẋ_W = -O_X_W @ Cross.vx(W_v_WF)

        case VelRepr.Mixed:
            # O = F[W]: only the translation of the frame is used, the
            # rotation is the identity.
            O_X_W = Adjoint.from_rotation_and_translation(
                rotation=jnp.eye(3, dtype=W_H_F.dtype),
                translation=W_H_F[0:3, 3],
                inverse=True,
            )
            FW_v_WF = O_X_W @ W_v_WF

            # Zero the last three (angular) components before building the
            # cross-product operator.
            W_v_W_FW = FW_v_WF.at[3:6].set(jnp.zeros_like(FW_v_WF[3:6]))
            O_Ẋ_W = -O_X_W @ Cross.vx(W_v_W_FW)

        case _:
            raise ValueError(output_vel_repr)

    # Product rule: d/dt (O_X_W @ W_J) = O_Ẋ_W @ W_J + O_X_W @ W_J̇.
    O_J̇_WF_I = O_Ẋ_W @ W_J_WL_I
    O_J̇_WF_I += O_X_W @ W_J̇_WL_I

    return O_J̇_WF_I
================================================
FILE: src/jaxsim/api/integrators.py
================================================
import dataclasses
from collections.abc import Callable
import jax
import jax.numpy as jnp
import jaxsim
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.api.data import JaxSimModelData
from jaxsim.math import Skew
def semi_implicit_euler_integration(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    link_forces: jtp.Vector,
    joint_torques: jtp.Vector,
) -> JaxSimModelData:
    """
    Integrate the system state using the semi-implicit Euler method.

    The velocities are integrated first using the current accelerations, and
    the positions are then integrated using the updated velocities.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_forces: The forces applied to the links.
        joint_torques: The torques applied to the joints.

    Returns:
        The data integrated over one `model.time_step`.
    """

    with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):

        # Compute the system acceleration
        W_v̇_WB, s̈, contact_state_derivative = js.ode.system_acceleration(
            model=model,
            data=data,
            link_forces=link_forces,
            joint_torques=joint_torques,
        )

        dt = model.time_step

        # Compute the new generalized velocity.
        new_generalized_acceleration = jnp.hstack([W_v̇_WB, s̈])

        new_generalized_velocity = (
            data.generalized_velocity + dt * new_generalized_acceleration
        )

        # Extract the new base and joint velocities.
        W_v_B = new_generalized_velocity[0:6]
        ṡ = new_generalized_velocity[6:]

        # Compute the new base position and orientation.
        W_ω_WB = new_generalized_velocity[3:6]

        # The linear part of the inertial-fixed base velocity is not the
        # derivative of the base position. To recover the latter, we need to add
        # the skew-symmetric matrix of the base angular velocity times the base
        # position (i.e. the cross product ω × p).
        # See: S. Traversaro and A. Saccon, "Multibody Dynamics Notation (Version 2)", pg. 9
        W_ṗ_B = new_generalized_velocity[0:3] + Skew.wedge(W_ω_WB) @ data.base_position

        W_Q̇_B = jaxsim.math.Quaternion.derivative(
            quaternion=data.base_orientation,
            omega=W_ω_WB,
            omega_in_body_fixed=False,
        ).squeeze()

        W_p_B = data.base_position + dt * W_ṗ_B
        W_Q_B = data.base_orientation + dt * W_Q̇_B

        # Normalize the integrated quaternion, guarding against a zero norm.
        base_quaternion_norm = jaxsim.math.safe_norm(W_Q_B, axis=-1)

        W_Q_B = W_Q_B / jnp.where(base_quaternion_norm == 0, 1.0, base_quaternion_norm)

        # Integrate the joint positions with the updated joint velocities.
        s = data.joint_positions + dt * ṡ

        # Integrate the contact state with an explicit Euler step.
        integrated_contact_state = jax.tree.map(
            lambda x, x_dot: x + dt * x_dot,
            data.contact_state,
            contact_state_derivative,
        )

    # Store the new state in the data, writing the private fields directly.
    data = dataclasses.replace(
        data,
        _base_quaternion=W_Q_B,
        _base_position=W_p_B,
        _joint_positions=s,
        _joint_velocities=ṡ,
        _base_linear_velocity=W_v_B[0:3],
        _base_angular_velocity=W_ω_WB,
        contact_state=integrated_contact_state,
    )

    # Recompute kinematic caches for the new state.
    data = data.replace(model=model)

    return data
def rk4_integration(
    model: js.model.JaxSimModel,
    data: JaxSimModelData,
    link_forces: jtp.Vector,
    joint_torques: jtp.Vector,
) -> JaxSimModelData:
    """
    Integrate the system state using the Runge-Kutta 4 method.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_forces: The forces applied to the links.
        joint_torques: The torques applied to the joints.

    Returns:
        The data integrated over one `model.time_step`.
    """

    dt = model.time_step

    def state_derivative(state) -> dict[str, jtp.Matrix]:
        """Evaluate the system dynamics at the given state."""

        with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):

            stage_data = data.replace(model=model, **state)

            return js.ode.system_dynamics(
                model=model,
                data=stage_data,
                link_forces=link_forces,
                joint_torques=joint_torques,
            )

    def half_step(x, dxdt):
        """Advance a state leaf by half of the time step."""
        return x + (0.5 * dt) * dxdt

    def full_step(x, dxdt):
        """Advance a state leaf by the whole time step."""
        return x + dt * dxdt

    def weighted_slope(k1, k2, k3, k4):
        """Combine the four stage slopes with the RK4 weights."""
        return (k1 + 2 * k2 + 2 * k3 + k4) / 6

    # Normalize the initial base quaternion, guarding against a zero norm.
    quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion, axis=-1)
    safe_quaternion_norm = jnp.where(quaternion_norm == 0, 1.0, quaternion_norm)

    # Initial state, keyed like the arguments accepted by `data.replace`.
    x_t0 = {
        "base_position": data._base_position,
        "base_quaternion": data._base_quaternion / safe_quaternion_norm,
        "joint_positions": data._joint_positions,
        "base_linear_velocity": data._base_linear_velocity,
        "base_angular_velocity": data._base_angular_velocity,
        "joint_velocities": data._joint_velocities,
        "contact_state": data.contact_state,
    }

    # Evaluate the four stages of the method.
    k1 = state_derivative(x_t0)
    k2 = state_derivative(jax.tree.map(half_step, x_t0, k1))
    k3 = state_derivative(jax.tree.map(half_step, x_t0, k2))
    k4 = state_derivative(jax.tree.map(full_step, x_t0, k3))

    # Average the slopes and take the final step.
    dxdt = jax.tree.map(weighted_slope, k1, k2, k3, k4)
    x_tf = jax.tree.map(full_step, x_t0, dxdt)

    # Store the final state in the data, writing the private fields directly.
    data_tf = dataclasses.replace(
        data,
        _base_position=x_tf["base_position"],
        _base_quaternion=x_tf["base_quaternion"],
        _joint_positions=x_tf["joint_positions"],
        _base_linear_velocity=x_tf["base_linear_velocity"],
        _base_angular_velocity=x_tf["base_angular_velocity"],
        _joint_velocities=x_tf["joint_velocities"],
        contact_state=x_tf["contact_state"],
    )

    # Recompute the quantities that depend on the new state.
    return data_tf.replace(model=model)
def rk4fast_integration(
    model: js.model.JaxSimModel,
    data: JaxSimModelData,
    link_forces: jtp.Vector,
    joint_torques: jtp.Vector,
) -> JaxSimModelData:
    """
    Integrate the system state using the Runge-Kutta 4 fast method.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_forces: The forces applied to the links.
        joint_torques: The torques applied to the joints.

    Returns:
        The data integrated over one `model.time_step`.

    Note:
        This method is a faster version of the RK4 method, but it may not be as accurate.
        It computes the contact forces only once at the beginning of the integration step.
    """

    dt = model.time_step

    # Defaults for models without collidable points: there are no terrain
    # forces to add and the contact state is taken as is from the data.
    W_f_L_total = link_forces
    contact_state = data.contact_state

    if len(model.kin_dyn_parameters.contact_parameters.body) > 0:

        # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
        # with the terrain.
        W_f_L_terrain, contact_state_derivative = js.contact.link_contact_forces(
            model=model,
            data=data,
            link_forces=link_forces,
            joint_torques=joint_torques,
        )

        W_f_L_total = link_forces + W_f_L_terrain

        # Update the contact state data. This is necessary only for the contact models
        # that require propagation and integration of contact state.
        contact_state = model.contact_model.update_contact_state(
            contact_state_derivative
        )

    def f(x) -> dict[str, jtp.Matrix]:
        """Evaluate the state derivative at `x` with the contact forces frozen."""

        with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):

            data_ti = data.replace(model=model, **x)

            W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
                model=model,
                data=data_ti,
                joint_forces=joint_torques,
                link_forces=W_f_L_total,
            )

            # The position derivative must be evaluated at the stage state:
            # using the initial data would make all the stages share the same
            # slope, degrading the position update to an explicit Euler step.
            W_ṗ_B, W_Q̇_B, ṡ = js.ode.system_position_dynamics(
                data=data_ti,
                baumgarte_quaternion_regularization=1.0,
            )

        return dict(
            base_position=W_ṗ_B,
            base_quaternion=W_Q̇_B,
            joint_positions=ṡ,
            base_linear_velocity=W_v̇_WB[0:3],
            base_angular_velocity=W_v̇_WB[3:6],
            joint_velocities=s̈,
            # The contact state is not updated here, as it is assumed to be constant.
            # NOTE(review): this entry is combined like all the other slopes, so a
            # non-empty contact state is not left untouched. Confirm the intended
            # semantics against the contact models before changing it.
            contact_state=data_ti.contact_state,
        )

    # Normalize the initial base quaternion, guarding against a zero norm.
    base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion, axis=-1)
    base_quaternion = data._base_quaternion / jnp.where(
        base_quaternion_norm == 0, 1.0, base_quaternion_norm
    )

    # Initial state, keyed like the arguments accepted by `data.replace`.
    x_t0 = dict(
        base_position=data._base_position,
        base_quaternion=base_quaternion,
        joint_positions=data._joint_positions,
        base_linear_velocity=data._base_linear_velocity,
        base_angular_velocity=data._base_angular_velocity,
        joint_velocities=data._joint_velocities,
        contact_state=contact_state,
    )

    euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt
    euler_fin = lambda x, dxdt: x + dt * dxdt

    # Evaluate the four stages of the method.
    k1 = f(x_t0)
    k2 = f(jax.tree.map(euler_mid, x_t0, k1))
    k3 = f(jax.tree.map(euler_mid, x_t0, k2))
    k4 = f(jax.tree.map(euler_fin, x_t0, k3))

    # Average the slopes and compute the RK4 state derivative.
    average = lambda k1, k2, k3, k4: (k1 + 2 * k2 + 2 * k3 + k4) / 6
    dxdt = jax.tree.map(average, k1, k2, k3, k4)

    # Integrate the dynamics
    x_tf = jax.tree.map(euler_fin, x_t0, dxdt)

    # Store the final state in the data, writing the private fields directly.
    data_tf = dataclasses.replace(
        data,
        _base_position=x_tf["base_position"],
        _base_quaternion=x_tf["base_quaternion"],
        _joint_positions=x_tf["joint_positions"],
        _base_linear_velocity=x_tf["base_linear_velocity"],
        _base_angular_velocity=x_tf["base_angular_velocity"],
        _joint_velocities=x_tf["joint_velocities"],
        contact_state=x_tf["contact_state"],
    )

    # Recompute the quantities that depend on the new state.
    return data_tf.replace(model=model)
# Lookup table associating each integrator type with the function implementing it.
# All the functions defined in this module share the same signature: they take
# the model, the data, the link forces and the joint torques, and return the
# data integrated over one `model.time_step`.
_INTEGRATORS_MAP: dict[
    js.model.IntegratorType, Callable[..., js.data.JaxSimModelData]
] = {
    js.model.IntegratorType.SemiImplicitEuler: semi_implicit_euler_integration,
    js.model.IntegratorType.RungeKutta4: rk4_integration,
    js.model.IntegratorType.RungeKutta4Fast: rk4fast_integration,
}
================================================
FILE: src/jaxsim/api/joint.py
================================================
import functools
from collections.abc import Sequence
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import exceptions
# =======================
# Index-related functions
# =======================
@functools.partial(jax.jit, static_argnames="joint_name")
@js.common.named_scope
def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
    """
    Convert the name of a joint to its index.

    Args:
        model: The model to consider.
        joint_name: The name of the joint.

    Returns:
        The index of the joint.

    Raises:
        ValueError: If the model has no joint with the given name.
    """

    known_joints = model.joint_names()

    if joint_name not in known_joints:
        raise ValueError(f"Joint '{joint_name}' not found in the model.")

    # Position of the joint in the list of names stored by the joint model,
    # which is the 1-based index used by the RBDAs.
    rbda_index = model.kin_dyn_parameters.joint_model.joint_names.index(joint_name)

    # Note: the index exposed to the users is 0-based, hence the -1.
    joint_index = jnp.array(rbda_index - 1)

    return joint_index.astype(int).squeeze()
def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
    """
    Convert the index of a joint to its name.

    Args:
        model: The model to consider.
        joint_index: The index of the joint.

    Returns:
        The name of the joint.
    """

    # Negative indices are rejected explicitly, since the lookup below would
    # otherwise accept them.
    is_negative = joint_index < 0

    exceptions.raise_value_error_if(
        condition=is_negative,
        msg="Invalid joint index '{idx}'",
        idx=joint_index,
    )

    # The names stored by the joint model are shifted by one position with
    # respect to the 0-based joint index.
    stored_names = model.kin_dyn_parameters.joint_model.joint_names

    return stored_names[joint_index + 1]
@functools.partial(jax.jit, static_argnames="joint_names")
@js.common.named_scope
def names_to_idxs(
    model: js.model.JaxSimModel, *, joint_names: Sequence[str]
) -> jax.Array:
    """
    Convert a sequence of joint names to their corresponding indices.

    Args:
        model: The model to consider.
        joint_names: The names of the joints.

    Returns:
        The indices of the joints.
    """

    # Resolve the names one at a time, preserving the order of the input.
    indices = []

    for joint_name in joint_names:
        indices.append(name_to_idx(model=model, joint_name=joint_name))

    return jnp.array(indices).astype(int)
def idxs_to_names(
    model: js.model.JaxSimModel,
    *,
    joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike,
) -> tuple[str, ...]:
    """
    Convert a sequence of joint indices to their corresponding names.

    Args:
        model: The model to consider.
        joint_indices: The indices of the joints.

    Returns:
        The names of the joints.
    """

    # Resolve the indices one at a time, preserving the order of the input.
    names = []

    for joint_index in joint_indices:
        names.append(idx_to_name(model=model, joint_index=joint_index))

    return tuple(names)
# ============
# Joint limits
# ============
@jax.jit
def position_limit(
    model: js.model.JaxSimModel, *, joint_index: jtp.IntLike
) -> tuple[jtp.Float, jtp.Float]:
    """
    Get the position limits of a joint.

    Args:
        model: The model to consider.
        joint_index: The index of the joint.

    Returns:
        The position limits of the joint, as a `(lower, upper)` pair.
        Two empty arrays are returned if the model has no joints.
    """

    num_joints = model.number_of_joints()

    # There are no limits to read for models without joints.
    if num_joints == 0:
        return jnp.empty(0).astype(float), jnp.empty(0).astype(float)

    # Joint indices are accepted only in [0, num_joints).
    out_of_range = jnp.array([joint_index < 0, joint_index >= num_joints]).any()

    exceptions.raise_value_error_if(
        condition=out_of_range,
        msg="Invalid joint index '{idx}'",
        idx=joint_index,
    )

    joint_parameters = model.kin_dyn_parameters.joint_parameters

    # Promote the stored limits to 1D before indexing them.
    lower_limits = jnp.atleast_1d(joint_parameters.position_limits_min)
    upper_limits = jnp.atleast_1d(joint_parameters.position_limits_max)

    return (
        lower_limits[joint_index].astype(float),
        upper_limits[joint_index].astype(float),
    )
@functools.partial(jax.jit, static_argnames=["joint_names"])
@js.common.named_scope
def position_limits(
    model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Get the position limits of a list of joints.

    Args:
        model: The model to consider.
        joint_names: The names of the joints (all the joints if None).

    Returns:
        The position limits of the joints, as a `(lower, upper)` pair.
        Two empty arrays are returned if no joint is selected.
    """

    # Select all the joints of the model if no names were given.
    if joint_names is not None:
        selected = names_to_idxs(joint_names=joint_names, model=model)
    else:
        selected = jnp.arange(model.number_of_joints())

    if len(selected) == 0:
        return jnp.empty(0).astype(float), jnp.empty(0).astype(float)

    joint_parameters = model.kin_dyn_parameters.joint_parameters

    lower_limits = joint_parameters.position_limits_min[selected]
    upper_limits = joint_parameters.position_limits_max[selected]

    return lower_limits.astype(float), upper_limits.astype(float)
# ======================
# Random data generation
# ======================
@functools.partial(jax.jit, static_argnames=["joint_names"])
@js.common.named_scope
def random_joint_positions(
    model: js.model.JaxSimModel,
    *,
    joint_names: Sequence[str] | None = None,
    key: jax.Array | None = None,
) -> jtp.Vector:
    """
    Generate random joint positions.

    The positions are sampled uniformly within the joint limits parsed from
    the model description.

    Args:
        model: The model to consider.
        joint_names: The names of the considered joints (all if None).
        key: The random key (initialized from seed 0 if None).

    Note:
        If the joint range of revolute joints is larger than 2π, their joint positions
        will be sampled from an interval of size 2π.

    Returns:
        The random joint positions.
    """

    # Consider the key corresponding to a zero seed if it was not passed.
    key = key if key is not None else jax.random.PRNGKey(seed=0)

    # Get the joint limits parsed from the model description.
    s_min, s_max = position_limits(model=model, joint_names=joint_names)

    # Get the joint indices.
    # Note that it will trigger an exception if the given `joint_names` are not valid.
    joint_names = joint_names if joint_names is not None else model.joint_names()
    joint_indices = names_to_idxs(model=model, joint_names=joint_names)

    from jaxsim.parsers.descriptions.joint import JointType

    # Filter for revolute joints.
    # The first stored joint type is skipped to match the 0-based joint indices.
    is_revolute = (
        jnp.array(model.kin_dyn_parameters.joint_model.joint_types[1:])[joint_indices]
        == JointType.Revolute
    )

    # Shorthand for π.
    π = jnp.pi

    # Filter for revolute with full range (or continuous).
    is_revolute_full_range = jnp.logical_and(is_revolute, s_max - s_min >= 2 * π)

    # Filter for the full-range joints whose limits contain [-π, π].
    # The mask is computed once, before the limits are modified.
    contains_pi_interval = jnp.logical_and(
        is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
    )

    # Clip the limits to [-π, π] if the joint range is larger than [-π, π].
    s_min = jnp.where(contains_pi_interval, -π, s_min)
    s_max = jnp.where(contains_pi_interval, π, s_max)

    # Shift the lower limit if the upper limit is smaller than +π.
    s_min = jnp.where(
        jnp.logical_and(is_revolute_full_range, s_max < π),
        s_max - 2 * π,
        s_min,
    )

    # Shift the upper limit if the lower limit is larger than -π.
    s_max = jnp.where(
        jnp.logical_and(is_revolute_full_range, s_min > -π),
        s_min + 2 * π,
        s_max,
    )

    # Sample the joint positions.
    s_random = jax.random.uniform(
        minval=s_min,
        maxval=s_max,
        key=key,
        shape=s_min.shape,
    )

    return s_random
================================================
FILE: src/jaxsim/api/kin_dyn_parameters.py
================================================
from __future__ import annotations
import dataclasses
from itertools import starmap
from typing import ClassVar
import jax.lax
import jax.numpy as jnp
import jax_dataclasses
import numpy as np
import numpy.typing as npt
from jax_dataclasses import Static
import jaxsim
import jaxsim.typing as jtp
from jaxsim.math import Inertia, JointModel, supported_joint_motion
from jaxsim.math.adjoint import Adjoint
from jaxsim.parsers.descriptions import JointDescription, JointType, ModelDescription
from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class KinDynParameters(JaxsimDataclass):
r"""
Class storing the kinematic and dynamic parameters of a model.
Attributes:
link_names: The names of the links.
parent_array: The parent array :math:`\lambda(i)` of the model.
support_body_array_bool:
The boolean support parent array :math:`\kappa_{b}(i)` of the model.
link_parameters: The parameters of the links.
frame_parameters: The parameters of the frames.
contact_parameters: The parameters of the collidable points.
joint_model: The joint model of the model.
joint_parameters: The parameters of the joints.
hw_link_metadata: The hardware parameters of the model links.
constraints: The kinematic constraints of the model. They can be used only with Relaxed-Rigid contact model.
"""
# Static
link_names: Static[tuple[str]]
_parent_array: Static[HashedNumpyArray]
_support_body_array_bool: Static[HashedNumpyArray]
_motion_subspaces: Static[HashedNumpyArray]
# Tree level structure for parallel algorithms.
# level_nodes: (n_levels, max_width) array of link indices at each depth level,
# padded with 0 for levels with fewer nodes than max_width.
# level_mask: (n_levels, max_width) boolean mask, True for real nodes.
_level_nodes: Static[HashedNumpyArray]
_level_mask: Static[HashedNumpyArray]
# Links
link_parameters: LinkParameters
# Contacts
contact_parameters: ContactParameters
# Frames
frame_parameters: FrameParameters
# Joints
joint_model: JointModel
joint_parameters: JointParameters | None
# Model hardware parameters
hw_link_metadata: HwLinkMetadata | None = dataclasses.field(default=None)
# Kinematic constraints
constraints: ConstraintMap | None = dataclasses.field(default=None)
@property
def motion_subspaces(self) -> jtp.Matrix:
    r"""
    Return the motion subspaces :math:`\mathbf{S}(s)` of the joints.

    The value is read from the static field `_motion_subspaces`.
    """

    wrapped = self._motion_subspaces
    return wrapped.get()
@property
def parent_array(self) -> jtp.Vector:
    r"""
    Return the parent array :math:`\lambda(i)` of the model.

    The value is read from the static field `_parent_array`.
    """

    wrapped = self._parent_array
    return wrapped.get()
@property
def support_body_array_bool(self) -> jtp.Matrix:
    r"""
    Return the boolean support parent array :math:`\kappa_{b}(i)` of the model.

    The value is read from the static field `_support_body_array_bool`.
    """

    wrapped = self._support_body_array_bool
    return wrapped.get()
@property
def level_nodes(self) -> jtp.Matrix:
    r"""
    Return the tree level nodes array of shape ``(n_levels, max_width)``.

    Row ``d`` lists the indices of the links located at depth ``d``. Rows
    with fewer than ``max_width`` links are padded with 0.
    """

    wrapped = self._level_nodes
    return wrapped.get()
@property
def level_mask(self) -> jtp.Matrix:
    r"""
    Return the tree level mask of shape ``(n_levels, max_width)``.

    An entry is ``True`` where the matching entry of the level nodes array
    is a real node, and ``False`` where it is padding.
    """

    wrapped = self._level_mask
    return wrapped.get()
@staticmethod
def _compute_tree_levels(
parent_array: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
"""
Compute the tree level decomposition from a parent array.
Args:
parent_array: Array of shape ``(n,)`` where ``parent_array[i]``
is the parent of link ``i``. ``parent_array[0] == -1`` for the root.
Returns:
A tuple ``(level_nodes, level_mask)`` where:
- ``level_nodes`` has shape ``(n_levels, max_width)`` with link
indices at each depth level (padded with 0).
- ``level_mask`` has shape ``(n_levels, max_width)`` with ``True``
for real nodes.
"""
import numpy as np
n = len(parent_array)
# Compute depth of each node.
depth = np.zeros(n, dtype=int)
for i in range(1, n):
depth[i] = depth[parent_array[i]] + 1
max_depth = int(depth.max()) if n > 0 else 0
n_levels = max_depth + 1
# Group nodes by depth level.
levels: list[list[int]] = [[] for _ in range(n_levels)]
for i in range(n):
levels[depth[i]].append(i)
max_width = max(len(lev) for lev in levels) if levels else 1
# Build padded arrays.
level_nodes = np.zeros((n_levels, max_width), dtype=int)
level_mask = np.zeros((n_levels, max_width), dtype=bool)
for d, lev in enumerate(levels):
for j, node_idx in enumerate(lev):
level_nodes[d, j] = node_idx
level_mask[d, j] = True
return level_nodes, level_mask
@staticmethod
def build(
    model_description: ModelDescription, constraints: ConstraintMap | None
) -> KinDynParameters:
    """
    Construct the kinematic and dynamic parameters of the model.

    The method builds, in this order: the vectorized link parameters, the
    vectorized joint parameters and the joint model, the contact and frame
    parameters, the tree properties (parent array, boolean support parent
    array, motion subspaces, level decomposition), and the constraints.

    Args:
        model_description: The parsed model description to consider.
        constraints:
            An object of type ConstraintMap specifying the kinematic constraint
            of the model. When None, a default-constructed ConstraintMap is used.

    Returns:
        The kinematic and dynamic parameters of the model.

    Note:
        This class is meant to ease the management of parametric models in
        an automatic differentiation context.
    """

    # Extract the links ordered by their index.
    # The link index corresponds to the body index ∈ [0, num_bodies - 1].
    ordered_links = sorted(
        list(model_description.links_dict.values()),
        key=lambda l: l.index,
    )

    # Extract the joints ordered by their index.
    # The joint index matches the index of its child link, therefore it starts
    # from 1. Keep this in mind since this 1-indexing might introduce bugs.
    ordered_joints = sorted(
        list(model_description.joints_dict.values()),
        key=lambda j: j.index,
    )

    # ================
    # Links properties
    # ================

    # Create a list of link parameters objects, one per link.
    link_parameters_list = [
        LinkParameters.build_from_spatial_inertia(index=link.index, M=link.inertia)
        for link in ordered_links
    ]

    # Create a vectorized object of link parameters by stacking, leaf by leaf,
    # the per-link objects.
    link_parameters = jax.tree.map(lambda *l: jnp.stack(l), *link_parameters_list)

    # =================
    # Joints properties
    # =================

    # Create a list of joint parameters objects, one per joint.
    joint_parameters_list = [
        JointParameters.build_from_joint_description(joint_description=joint)
        for joint in ordered_joints
    ]

    # Create a vectorized object of joint parameters.
    # A model without joints gets an object whose fields are all empty arrays,
    # since there would be nothing to stack.
    joint_parameters = (
        jax.tree.map(lambda *l: jnp.stack(l), *joint_parameters_list)
        if ordered_joints
        else JointParameters(
            index=jnp.array([], dtype=int),
            friction_static=jnp.array([], dtype=float),
            friction_viscous=jnp.array([], dtype=float),
            position_limits_min=jnp.array([], dtype=float),
            position_limits_max=jnp.array([], dtype=float),
            position_limit_spring=jnp.array([], dtype=float),
            position_limit_damper=jnp.array([], dtype=float),
        )
    )

    # Create an object that defines the joint model (parent-to-child transforms).
    joint_model = JointModel.build(description=model_description)

    # ===================
    # Contacts properties
    # ===================

    # Create the object storing the parameters of collidable points.
    # Note that, unlike LinkParameters and JointParameters, this object
    # is not created with vmap. This is because the "body" attribute of the object
    # must be Static for JIT-related reasons, and tree_map would not consider it
    # as a leaf.
    contact_parameters = ContactParameters.build_from(
        model_description=model_description
    )

    # =================
    # Frames properties
    # =================

    # Create the object storing the parameters of frames.
    # Note that, unlike LinkParameters and JointParameters, this object
    # is not created with vmap. This is because the "name" attribute of the object
    # must be Static for JIT-related reasons, and tree_map would not consider it
    # as a leaf.
    frame_parameters = FrameParameters.build_from(
        model_description=model_description
    )

    # ===============
    # Tree properties
    # ===============

    # Build the parent array λ(i) of the model.
    # Note: the parent of the base link is not set since it's not defined.
    # Links without a parent name are skipped here, and a single -1 entry is
    # prepended below as the parent of the link at index 0.
    parent_array_dict = {
        link.index: model_description.links_dict[link.parent_name].index
        for link in ordered_links
        if link.parent_name is not None
    }
    parent_array = jnp.array([-1, *list(parent_array_dict.values())], dtype=int)

    # Instead of building the support parent array κ(i) for each link of the model,
    # that has a variable length depending on the number of links connecting the
    # root to the i-th link, we build the corresponding boolean version.
    # Given a link index i, the boolean support parent array κb(i) is an array
    # with the same number of elements of λ(i) having the i-th element set to True
    # if the i-th link is in the support parent array κ(i), False otherwise.
    # We store the boolean κb(i) as static attribute of the PyTree so that
    # algorithms that need to access it can be jit-compiled.
    def κb(link_index: jtp.IntLike) -> jtp.Vector:

        # Start with no link marked, and with the queried link as the active one.
        κb = jnp.zeros(len(ordered_links), dtype=bool)

        carry0 = κb, link_index

        def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:

            κb, active_link_index = carry

            # When the scanned index reaches the active link, mark it and make
            # its parent the new active link. Otherwise leave the carry as is.
            κb, active_link_index = jax.lax.cond(
                pred=(i == active_link_index),
                false_fun=lambda: (κb, active_link_index),
                true_fun=lambda: (
                    κb.at[active_link_index].set(True),
                    parent_array[active_link_index],
                ),
            )

            return (κb, active_link_index), None

        # The link indices are scanned from the highest to the lowest.
        # NOTE(review): a parent is only marked if its index is lower than the
        # index of its child — presumably guaranteed by the link indexing; confirm.
        (κb, _), _ = jax.lax.scan(
            f=scan_body,
            init=carry0,
            xs=jnp.flip(jnp.arange(start=0, stop=len(ordered_links))),
        )

        return κb

    # Compute κb(i) for all links at once, one row per link index.
    support_body_array_bool = jax.vmap(κb)(
        jnp.arange(start=0, stop=len(ordered_links))
    )

    def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:
        # Build the (6, 1) motion subspace for the given joint type.
        # The axis vector is read from the `axis` attribute of the argument.
        # It fills the last three rows for revolute joints and the first three
        # rows for prismatic joints. Fixed joints get a zero column.
        S = {
            JointType.Fixed: np.zeros(shape=(6, 1)),
            JointType.Revolute: np.vstack(np.hstack([np.zeros(3), axis.axis])),
            JointType.Prismatic: np.vstack(np.hstack([axis.axis, np.zeros(3)])),
        }

        return S[joint_type]

    # Stack the motion subspaces of all joints. The first entry of the joint types
    # is skipped, and `strict=True` makes zip raise if the remaining types and the
    # joint axes have different lengths.
    S_J = jnp.array(
        list(
            starmap(
                motion_subspace,
                zip(
                    joint_model.joint_types[1:], joint_model.joint_axis, strict=True
                ),
            )
        )
        if len(joint_model.joint_axis) != 0
        else jnp.empty((0, 6, 1))
    )

    # Prepend a zero motion subspace at index 0.
    motion_subspaces = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])

    # ====================
    # Tree level structure
    # ====================

    # The level decomposition is computed with NumPy, so the parent array is
    # rebuilt here as a NumPy array with the same content as `parent_array`.
    parent_array_np = np.array([-1, *list(parent_array_dict.values())], dtype=int)
    level_nodes, level_mask = KinDynParameters._compute_tree_levels(parent_array_np)

    # ===========
    # Constraints
    # ===========

    # Fall back to a default-constructed ConstraintMap when none is given.
    constraints = ConstraintMap() if constraints is None else constraints

    # =================================
    # Build and return KinDynParameters
    # =================================

    return KinDynParameters(
        link_names=tuple(l.name for l in ordered_links),
        _parent_array=HashedNumpyArray(array=parent_array),
        _support_body_array_bool=HashedNumpyArray(array=support_body_array_bool),
        _motion_subspaces=HashedNumpyArray(array=motion_subspaces),
        _level_nodes=HashedNumpyArray(array=level_nodes),
        _level_mask=HashedNumpyArray(array=level_mask),
        link_parameters=link_parameters,
        joint_model=joint_model,
        joint_parameters=joint_parameters,
        contact_parameters=contact_parameters,
        frame_parameters=frame_parameters,
        constraints=constraints,
    )
def __eq__(self, other: KinDynParameters) -> bool:
    # Objects of any other type never compare equal.
    is_same_type = isinstance(other, KinDynParameters)

    # For objects of the same type, equality is decided by comparing hashes.
    return is_same_type and hash(self) == hash(other)
def __hash__(self) -> int:
return hash(
(
hash(self.number_of_links()),
hash(self.number_of_joints()),
hash(self.frame_parameters.name),
hash(self.frame_parameters.body),
hash(self._parent_array),
hash(self._support_body_array_bool),
)
)
# =============================
# Helpers to extract parameters
# =============================
def number_of_links(self) -> int:
    """
    Count the links of the model.

    Returns:
        The number of entries stored in ``link_names``.
    """

    names = self.link_names
    return len(names)
def number_of_joints(self) -> int:
    """
    Count the joints of the model.

    Returns:
        The number of joint names stored in the joint model, minus one.
    """

    names = self.joint_model.joint_names

    # One entry of the joint names is excluded from the count.
    return len(names) - 1
def number_of_frames(self) -> int:
    """
    Count the frames of the model.

    Returns:
        The number of names stored in the frame parameters.
    """

    frame_names = self.frame_parameters.name
    return len(frame_names)
def support_body_array(self, link_index: jtp.IntLike) -> jtp.Vector:
    r"""
    Return the support parent array :math:`\kappa(i)` of a link.

    Args:
        link_index: The index of the link.

    Returns:
        The indices at which the row ``link_index`` of the boolean support
        parent array is True.

    Note:
        This method returns a variable-length vector. In jit-compiled functions,
        it's better to use the (static) boolean version `support_body_array_bool`.
    """

    # Row of the boolean support parent array of the requested link.
    link_mask = self.support_body_array_bool[link_index]

    # Positions of the True entries of the row.
    true_positions = jnp.where(link_mask)

    return jnp.array(true_positions[0], dtype=int)
# ========================
# Quantities used by RBDAs
# ========================
@jax.jit
def links_spatial_inertia(self) -> jtp.Array:
    """
    Compute the spatial inertia of every link of the model.

    Returns:
        The result of mapping ``LinkParameters.spatial_inertia`` over the
        vectorized link parameters.
    """

    batched_spatial_inertia = jax.vmap(LinkParameters.spatial_inertia)

    return batched_spatial_inertia(self.link_parameters)
@jax.jit
def tree_transforms(self) -> jtp.Array:
    r"""
    Return the tree transforms of the model.

    Returns:
        The transforms
        :math:`{}^{\text{pre}(i)} H_{\lambda(i)}`
        of all joints of the model, stacked after a leading
        :math:`6 \times 6` block of zeros.
    """

    def predecessor_from_parent(joint_index):
        # Invert the parent-to-predecessor transform of the joint and take
        # the adjoint of the result.
        parent_H_pre = self.joint_model.parent_H_predecessor(joint_index=joint_index)
        return parent_H_pre.inverse().adjoint()

    # Joint indices start from 1.
    joint_indices = jnp.arange(1, self.number_of_joints() + 1)

    transforms_of_joints = jax.vmap(predecessor_from_parent)(joint_indices)

    # Entry 0 is filled with zeros.
    leading_zeros = jnp.zeros(shape=(1, 6, 6), dtype=float)

    return jnp.vstack([leading_zeros, transforms_of_joints])
@jax.jit
def joint_transforms(
    self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
) -> jtp.Array:
    r"""
    Return the transforms of the joints.

    Args:
        joint_positions: The joint positions.
        base_transform: The homogeneous matrix defining the base pose.

    Returns:
        The stacked transforms
        :math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)`
        of each joint.
    """

    joint_model = self.joint_model
    n_joints = self.number_of_joints()

    # Fixed parent-to-predecessor transforms of the joints, with an identity
    # stacked at index 0.
    parent_H_pre = jnp.vstack(
        [
            jnp.eye(4)[jnp.newaxis],
            joint_model.λ_H_pre[1 : 1 + n_joints],
        ]
    )

    # Predecessor-to-successor transforms produced by the joint positions.
    if n_joints != 0:
        pre_H_suc_of_joints = jax.vmap(supported_joint_motion)(
            joint_types=jnp.array(joint_model.joint_types[1:]).astype(int),
            joint_positions=jnp.array(joint_positions),
            joint_axes=jnp.array([j.axis for j in joint_model.joint_axis]),
        )
    else:
        pre_H_suc_of_joints = jnp.empty((0, 4, 4))

    # The base transform is stacked at index 0, in front of the transforms
    # of the joints.
    pre_H_suc = jnp.vstack([base_transform[jnp.newaxis, ...], pre_H_suc_of_joints])

    # Fixed successor-to-child transforms. Index 0 is included as well, since
    # it stores the optional pose of the base link w.r.t. the root frame of
    # the model (supported by SDF when the base link element is defined).
    suc_H_child = joint_model.suc_H_i[jnp.arange(0, 1 + n_joints)]

    def child_from_parent(parent_H_pre_i, pre_H_suc_i, suc_H_child_i):
        # Compose the three components of the joint model, then build the
        # adjoint of the inverse of the composed transform.
        parent_H_child_i = parent_H_pre_i @ pre_H_suc_i @ suc_H_child_i
        return Adjoint.from_transform(transform=parent_H_child_i, inverse=True)

    return jax.vmap(child_from_parent)(parent_H_pre, pre_H_suc, suc_H_child)
# ============================
# Helpers to update parameters
# ============================
def set_link_mass(
    self, link_index: jtp.IntLike, mass: jtp.FloatLike
) -> KinDynParameters:
    """
    Set the mass of a link.

    Args:
        link_index: The index of the link.
        mass: The mass of the link.

    Returns:
        The updated kinematic and dynamic parameters of the model.
    """

    current_link_parameters = self.link_parameters

    # Functional update: a new array of masses is created.
    updated_masses = current_link_parameters.mass.at[link_index].set(mass)

    return self.replace(
        link_parameters=current_link_parameters.replace(mass=updated_masses)
    )
def set_link_inertia(
    self, link_index: jtp.IntLike, inertia: jtp.MatrixLike
) -> KinDynParameters:
    r"""
    Set the inertia tensor of a link.

    Args:
        link_index: The index of the link.
        inertia: The :math:`3 \times 3` inertia tensor of the link.

    Returns:
        The updated kinematic and dynamic parameters of the model.
    """

    current_link_parameters = self.link_parameters

    # Only the flattened unique elements of the tensor are stored.
    flattened_inertia = LinkParameters.flatten_inertia_tensor(I=inertia)

    # Functional update: a new array of inertia elements is created.
    updated_elements = current_link_parameters.inertia_elements.at[link_index].set(
        flattened_inertia
    )

    return self.replace(
        link_parameters=current_link_parameters.replace(
            inertia_elements=updated_elements
        )
    )
@jax_dataclasses.pytree_dataclass
class JointParameters(JaxsimDataclass):
    """
    Container of the parameters of a joint.

    Attributes:
        index: The index of the joint.
        friction_static: The static friction of the joint.
        friction_viscous: The viscous friction of the joint.
        position_limits_min: The lower position limit of the joint.
        position_limits_max: The upper position limit of the joint.
        position_limit_spring: The spring constant of the position limit.
        position_limit_damper: The damper constant of the position limit.

    Note:
        This class is used inside KinDynParameters to store the vectorized set
        of joint parameters.
    """

    index: jtp.Int

    friction_static: jtp.Float
    friction_viscous: jtp.Float

    position_limits_min: jtp.Float
    position_limits_max: jtp.Float

    position_limit_spring: jtp.Float
    position_limit_damper: jtp.Float

    @staticmethod
    def build_from_joint_description(
        joint_description: JointDescription,
    ) -> JointParameters:
        """
        Create a JointParameters object out of a joint description.

        Args:
            joint_description: The joint description to consider.

        Returns:
            The JointParameters object.
        """

        def as_float_scalar(value):
            # Convert to an array, drop singleton dimensions, cast to float.
            return jnp.array(value).squeeze().astype(float)

        # The two limits are sorted below, so their order in the description
        # does not matter.
        first_limit = joint_description.position_limit[0]
        second_limit = joint_description.position_limit[1]

        lower_limit = jnp.minimum(first_limit, second_limit)
        upper_limit = jnp.maximum(first_limit, second_limit)

        return JointParameters(
            index=jnp.array(joint_description.index).squeeze().astype(int),
            friction_static=as_float_scalar(joint_description.friction_static),
            friction_viscous=as_float_scalar(joint_description.friction_viscous),
            position_limits_min=lower_limit.astype(float),
            position_limits_max=upper_limit.astype(float),
            position_limit_spring=as_float_scalar(
                joint_description.position_limit_spring
            ),
            position_limit_damper=as_float_scalar(
                joint_description.position_limit_damper
            ),
        )
@jax_dataclasses.pytree_dataclass
class LinkParameters(JaxsimDataclass):
    r"""
    Class storing the parameters of a link.

    Attributes:
        index: The index of the link.
        mass: The mass of the link.
        inertia_elements:
            The unique elements of the :math:`3 \times 3` inertia tensor of the link.
        center_of_mass:
            The translation :math:`{}^L \mathbf{p}_{\text{CoM}}` between the origin
            of the link frame and the link's center of mass, expressed in the
            coordinates of the link frame.

    Note:
        This class is used inside KinDynParameters to store the vectorized set
        of link parameters.
    """

    index: jtp.Int
    mass: jtp.Float
    center_of_mass: jtp.Vector
    inertia_elements: jtp.Vector

    @staticmethod
    def build_from_spatial_inertia(index: jtp.IntLike, M: jtp.Matrix) -> LinkParameters:
        r"""
        Build a LinkParameters object from a :math:`6 \times 6` spatial inertia matrix.

        Args:
            index: The index of the link.
            M: The :math:`6 \times 6` spatial inertia matrix of the link.

        Returns:
            The LinkParameters object.
        """

        # Extract the link parameters from the 6D spatial inertia.
        m, L_p_CoM, I_CoM = Inertia.to_params(M=M)

        # Extract only the necessary elements of the inertia tensor
        # (upper triangle, diagonal included).
        inertia_elements = I_CoM[jnp.triu_indices(3)]

        return LinkParameters(
            index=jnp.array(index).squeeze().astype(int),
            mass=jnp.array(m).squeeze().astype(float),
            center_of_mass=jnp.atleast_1d(jnp.array(L_p_CoM).squeeze()).astype(float),
            inertia_elements=jnp.atleast_1d(inertia_elements.squeeze()).astype(float),
        )

    @staticmethod
    def build_from_inertial_parameters(
        index: jtp.IntLike, m: jtp.FloatLike, I: jtp.MatrixLike, c: jtp.VectorLike
    ) -> LinkParameters:
        r"""
        Build a LinkParameters object from the inertial parameters of a link.

        Args:
            index: The index of the link.
            m: The mass of the link.
            I: The :math:`3 \times 3` inertia tensor of the link.
            c: The translation between the link frame and the link's center of mass.

        Returns:
            The LinkParameters object.
        """

        # The arguments are array-like: convert them so that plain sequences
        # support the indexing and the `squeeze` used below.
        I = jnp.asarray(I)
        c = jnp.asarray(c)

        # Extract only the necessary elements of the inertia tensor
        # (upper triangle, diagonal included).
        inertia_elements = I[jnp.triu_indices(3)]

        return LinkParameters(
            index=jnp.array(index).squeeze().astype(int),
            mass=jnp.array(m).squeeze().astype(float),
            center_of_mass=jnp.atleast_1d(c.squeeze()).astype(float),
            inertia_elements=jnp.atleast_1d(inertia_elements.squeeze()).astype(float),
        )

    @staticmethod
    def build_from_flat_parameters(
        index: jtp.IntLike, parameters: jtp.VectorLike
    ) -> LinkParameters:
        """
        Build a LinkParameters object from a flat vector of parameters.

        The vector is laid out as ``[mass, center of mass (3), inertia elements]``,
        which is the layout produced by `flat_parameters`.

        Args:
            index: The index of the link.
            parameters: The flat vector of parameters.

        Returns:
            The LinkParameters object.
        """

        # The argument is array-like: convert it so that plain sequences
        # support the slicing and the `squeeze` used below.
        parameters = jnp.asarray(parameters)

        index = jnp.array(index).squeeze().astype(int)

        m = jnp.array(parameters[0]).squeeze().astype(float)
        c = jnp.atleast_1d(parameters[1:4].squeeze()).astype(float)
        inertia_elements = jnp.atleast_1d(parameters[4:].squeeze()).astype(float)

        return LinkParameters(
            index=index, mass=m, inertia_elements=inertia_elements, center_of_mass=c
        )

    @staticmethod
    def flat_parameters(params: LinkParameters) -> jtp.Vector:
        """
        Return the parameters of a link as a flat vector.

        The vector is laid out as ``[mass, center of mass, inertia elements]``.

        Args:
            params: The link parameters.

        Returns:
            The parameters of the link as a flat vector.
        """

        return (
            jnp.hstack(
                [
                    params.mass,
                    params.center_of_mass.squeeze(),
                    params.inertia_elements,
                ]
            )
            .squeeze()
            .astype(float)
        )

    @staticmethod
    def inertia_tensor(params: LinkParameters) -> jtp.Matrix:
        r"""
        Return the :math:`3 \times 3` inertia tensor of a link.

        Args:
            params: The link parameters.

        Returns:
            The :math:`3 \times 3` inertia tensor of the link.
        """

        return LinkParameters.unflatten_inertia_tensor(
            inertia_elements=params.inertia_elements
        )

    @staticmethod
    def spatial_inertia(params: LinkParameters) -> jtp.Matrix:
        r"""
        Return the :math:`6 \times 6` spatial inertia matrix of a link.

        Args:
            params: The link parameters.

        Returns:
            The :math:`6 \times 6` spatial inertia matrix of the link.
        """

        return Inertia.to_sixd(
            mass=params.mass,
            I=LinkParameters.inertia_tensor(params),
            com=params.center_of_mass,
        )

    @staticmethod
    def flatten_inertia_tensor(I: jtp.Matrix) -> jtp.Vector:
        r"""
        Flatten a :math:`3 \times 3` inertia tensor into a vector of unique elements.

        Args:
            I: The :math:`3 \times 3` inertia tensor.

        Returns:
            The vector of unique elements of the inertia tensor, i.e. the entries
            of its upper triangle (diagonal included).
        """

        # Convert first, so that array-like inputs (e.g. nested lists) can be
        # indexed with the indices of the upper triangle.
        I = jnp.asarray(I)

        return jnp.atleast_1d(I[jnp.triu_indices(3)].squeeze())

    @staticmethod
    def unflatten_inertia_tensor(inertia_elements: jtp.Vector) -> jtp.Matrix:
        r"""
        Unflatten a vector of unique elements into a :math:`3 \times 3` inertia tensor.

        Args:
            inertia_elements: The vector of unique elements of the inertia tensor.

        Returns:
            The :math:`3 \times 3` inertia tensor.
        """

        # Convert first, so that array-like inputs support `squeeze`.
        inertia_elements = jnp.asarray(inertia_elements)

        # Fill the upper triangle (diagonal included) of a zero matrix.
        I = jnp.zeros([3, 3]).at[jnp.triu_indices(3)].set(inertia_elements.squeeze())

        # Keep the non-zero entries and take the others from the transpose,
        # which mirrors the upper triangle onto the lower one.
        return jnp.atleast_2d(jnp.where(I, I, I.T)).astype(float)
@jax_dataclasses.pytree_dataclass
class ContactParameters(JaxsimDataclass):
    """
    Class storing the contact parameters of a model.

    Attributes:
        body:
            A tuple of integers representing, for each collidable point, the index of
            the body (link) to which it is rigidly attached to.
        point:
            The translations between the link frame and the collidable point, expressed
            in the coordinates of the parent link frame.
        enabled:
            A tuple of booleans representing, for each collidable point, whether it is
            enabled or not in contact models.

    Note:
        Unlike LinkParameters and JointParameters, this class is not meant
        to be created with vmap. This is because the `body` attribute must be `Static`.
    """

    body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple)

    point: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([]))

    enabled: Static[tuple[bool, ...]] = dataclasses.field(default_factory=tuple)

    @property
    def indices_of_enabled_collidable_points(self) -> npt.NDArray:
        """
        Return the indices of the enabled collidable points.
        """

        return np.where(np.array(self.enabled))[0]

    @staticmethod
    def build_from(model_description: ModelDescription) -> ContactParameters:
        """
        Build a ContactParameters object from a model description.

        An object with default (empty) fields is returned when the model has
        no collision shapes or no enabled collidable points.

        Args:
            model_description: The model description to consider.

        Returns:
            The ContactParameters object.
        """

        if len(model_description.collision_shapes) == 0:
            return ContactParameters()

        # Get all the links so that we can take their updated index.
        links_dict = {link.name: link for link in model_description}

        # Get all the enabled collidable points of the model.
        collidable_points = model_description.all_enabled_collidable_points()

        # Collision shapes may exist without any enabled collidable point.
        # Stacking an empty list of positions would raise, hence return the same
        # empty object used when there are no collision shapes.
        if len(collidable_points) == 0:
            return ContactParameters()

        # Extract the positions L_p_C of the collidable points w.r.t. the link frames
        # they are rigidly attached to.
        points = jnp.vstack([cp.position for cp in collidable_points])

        # Extract the indices of the links to which the collidable points are rigidly
        # attached to.
        link_index_of_points = tuple(
            links_dict[cp.parent_link.name].index for cp in collidable_points
        )

        # Build the ContactParameters object.
        # All the points collected above are marked as enabled.
        cp = ContactParameters(
            point=points,
            body=link_index_of_points,
            enabled=tuple(True for _ in link_index_of_points),
        )

        # Sanity checks: one 3D position per collidable point.
        assert cp.point.shape[1] == 3, cp.point.shape[1]
        assert cp.point.shape[0] == len(cp.body), cp.point.shape[0]

        return cp
@jax_dataclasses.pytree_dataclass
class FrameParameters(JaxsimDataclass):
    """
    Container of the frame parameters of a model.

    Attributes:
        name: A tuple of strings defining the frame names.
        body:
            A vector of integers representing, for each frame, the index of
            the body (link) to which it is rigidly attached to.
        transform: The transforms of the frames w.r.t. their parent link.

    Note:
        Unlike LinkParameters and JointParameters, this class is not meant
        to be created with vmap. This is because the `name` attribute must be `Static`.
    """

    name: Static[tuple[str, ...]] = dataclasses.field(default_factory=tuple)

    body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple)

    transform: jtp.Array = dataclasses.field(default_factory=lambda: jnp.array([]))

    @staticmethod
    def build_from(model_description: ModelDescription) -> FrameParameters:
        """
        Create a FrameParameters object out of a model description.

        Args:
            model_description: The model description to consider.

        Returns:
            The FrameParameters object. Its fields keep their defaults when
            the model description has no frames.
        """

        # Nothing to collect for a model without frames.
        if len(model_description.frames) == 0:
            return FrameParameters()

        # Names of the frames.
        frame_names = tuple(frame.name for frame in model_description.frames)

        # Index of the link each frame is attached to.
        parent_link_indices = tuple(
            model_description.links_dict[frame.parent_name].index
            for frame in model_description.frames
        )

        # Transform of each frame w.r.t. its parent link, stacked along the
        # first axis.
        frame_poses = [frame.pose for frame in model_description.frames]
        stacked_transforms = jnp.atleast_3d(jnp.stack(frame_poses))

        fp = FrameParameters(
            name=frame_names,
            transform=stacked_transforms.astype(float),
            body=parent_link_indices,
        )

        # Sanity checks: one 4x4 transform per frame.
        assert fp.transform.shape[1:] == (4, 4), fp.transform.shape[1:]
        assert fp.transform.shape[0] == len(fp.body), fp.transform.shape[0]

        return fp
@dataclasses.dataclass(frozen=True)
class LinkParametrizableShape:
    """
    Enum-like namespace of the integer codes of the shapes supported by the
    HW parametrization.
    """

    # Code of the shapes that are not supported.
    Unsupported: ClassVar[int] = -1

    # Codes of the supported shapes.
    Box: ClassVar[int] = 0
    Cylinder: ClassVar[int] = 1
    Sphere: ClassVar[int] = 2
    Mesh: ClassVar[int] = 3
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class HwLinkMetadata(JaxsimDataclass):
"""
Class storing the hardware parameters of a link.
Attributes:
link_shape: The shape of the link.
0 = box, 1 = cylinder, 2 = sphere, 3 = mesh, -1 = unsupported.
geometry: Shape parameters used by HW parametrization.
box: [lx,ly,lz], cylinder: [r,l,0], sphere: [r,0,0],
mesh: cumulative anisotropic scale factors [sx,sy,sz] (initialized to [1,1,1]).
density: The density of the link.
L_H_G: The homogeneous transformation matrix from the link frame to the CoM frame G.
L_H_vis: The homogeneous transformation matrix from the link frame to the visual frame.
L_H_pre_mask: The mask indicating the link's child joint indices.
L_H_pre: The homogeneous transforms for child joints.
mesh_moments: Precomputed volumetric moments for mesh shapes (n_links x 13).
Each row stores [V_ref, com_x, com_y, com_z, Σ_00..Σ_22] where V_ref is the
reference volume, com is the volumetric center of mass, and Σ is the
volumetric covariance matrix at the origin. Zero for non-mesh links.
mesh_vertices: The original centered mesh vertices (Nx3) for mesh shapes, None otherwise.
mesh_faces: The mesh triangle faces (Mx3 integer indices) for mesh shapes, None otherwise.
mesh_offset: The original mesh centroid offset (3D vector) for mesh shapes, None otherwise.
mesh_uri: The path to the mesh file for reference, None otherwise.
"""
link_shape: jtp.Vector
geometry: jtp.Vector
density: jtp.Float
L_H_G: jtp.Matrix
L_H_vis: jtp.Matrix
L_H_pre_mask: jtp.Vector
L_H_pre: jtp.Matrix
mesh_moments: jtp.Matrix
mesh_vertices: Static[tuple[HashedNumpyArray | None, ...] | None]
mesh_faces: Static[tuple[HashedNumpyArray | None, ...] | None]
mesh_offset: Static[tuple[HashedNumpyArray | None, ...] | None]
mesh_uri: Static[tuple[str | None, ...] | None]
@classmethod
def empty(cls) -> HwLinkMetadata:
    """
    Create the hardware metadata representing the absence of links.

    Returns:
        An object whose array fields have no entries and whose mesh-related
        static fields are None.
    """

    def no_entries(dtype):
        # Array without elements of the requested dtype.
        return jnp.array([], dtype=dtype)

    return cls(
        link_shape=no_entries(int),
        geometry=no_entries(float),
        density=no_entries(float),
        L_H_G=no_entries(float),
        L_H_vis=no_entries(float),
        L_H_pre_mask=no_entries(bool),
        L_H_pre=no_entries(float),
        # Zero rows, each of which would hold 13 moments.
        mesh_moments=jnp.zeros((0, 13), dtype=float),
        mesh_vertices=None,
        mesh_faces=None,
        mesh_offset=None,
        mesh_uri=None,
    )
@staticmethod
def compute_mesh_inertia(
vertices: jtp.Matrix, faces: jtp.Matrix, density: jtp.Float
) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]:
"""
Compute mass, center of mass, and inertia tensor from mesh geometry.
Uses the divergence theorem to compute volumetric properties by integrating
over tetrahedra formed between the mesh surface and the origin.
Args:
vertices: Mesh vertices (Nx3) in the link frame, should be centered.
faces: Triangle face indices (Mx3), integer indices into vertices array.
density: Material density.
Returns:
A tuple containing the computed mass, the CoM position and the 3x3
inertia tensor at the CoM.
"""
# Extract triangles from vertices using face indices
triangles = vertices[faces.astype(int)]
A, B, C = triangles[:, 0], triangles[:, 1], triangles[:, 2]
# Compute signed volume of tetrahedra relative to origin
# vol = 1/6 * (A . (B x C))
tetrahedron_volumes = jnp.sum(A * jnp.cross(B, C), axis=1) / 6.0
total_signed_volume = jnp.sum(tetrahedron_volumes)
# Normalize the global winding sign so positive density yields non-negative mass.
orientation_sign = jnp.where(total_signed_volume < 0, -1.0, 1.0)
tetrahedron_volumes = tetrahedron_volumes * orientation_sign
total_volume = jnp.sum(tetrahedron_volumes)
eps = jnp.asarray(1e-12, dtype=total_volume.dtype)
is_valid_volume = jnp.abs(total_volume) > eps
safe_total_volume = jnp.where(is_valid_volume, total_volume, 1.0)
mass = jnp.where(is_valid_volume, total_volume * density, 0.0)
# Compute center of mass
tet_coms = (A + B + C) / 4.0
com_position = jnp.where(
is_valid_volume,
jnp.sum(tet_coms * tetrahedron_volumes[:, None], axis=0)
/ safe_total_volume,
jnp.zeros(3, dtype=vertices.dtype),
)
# Compute inertia tensor with covariance approach
def compute_tetrahedron_covariance(a, b, c, vol):
s = a + b + c
return (vol / 20.0) * (
jnp.outer(a, a) + jnp.outer(b, b) + jnp.outer(c, c) + jnp.outer(s, s)
)
covariance_matrices = jax.vmap(compute_tetrahedron_covariance)(
A, B, C, tetrahedron_volumes
)
Σ_origin = jnp.sum(covariance_matrices, axis=0)
# Shift to CoM using parallel axis theorem
Σ_com = Σ_origin * density - mass * jnp.outer(com_position, com_position)
# Convert covariance to inertia tensor
I_com = jnp.trace(Σ_com) * jnp.eye(3, dtype=vertices.dtype) - Σ_com
I_com = jnp.where(
is_valid_volume, I_com, jnp.zeros((3, 3), dtype=vertices.dtype)
)
return mass, com_position, I_com
@staticmethod
def precompute_mesh_moments(vertices: np.ndarray, faces: np.ndarray) -> np.ndarray:
"""
Precompute volumetric moments from reference mesh geometry.
Computes the reference volume, center of mass, and volumetric covariance
matrix at the origin using numpy. These 13 scalars are sufficient to
analytically reconstruct mass and inertia under any anisotropic scaling,
avoiding the need to embed full mesh arrays in JIT-compiled programs.
Args:
vertices: Mesh vertices (Nx3), should be centered.
faces: Triangle face indices (Mx3).
Returns:
A 13-element array: [V_ref, com_x, com_y, com_z, Σ_00..Σ_22].
"""
triangles = vertices[faces.astype(int)]
A, B, C = triangles[:, 0], triangles[:, 1], triangles[:, 2]
volumes = np.sum(A * np.cross(B, C), axis=1) / 6.0
total_signed = np.sum(volumes)
sign = np.sign(total_signed) if abs(total_signed) > 1e-12 else 1.0
volumes = volumes * sign
V_ref = np.sum(volumes)
if abs(V_ref) < 1e-12:
return np.zeros(13, dtype=np.float64)
# Center of mass
com = np.sum(volumes[:, None] * (A + B + C) / 4.0, axis=0) / V_ref
# Volumetric covariance at origin (same formula as compute_mesh_inertia)
S = A + B + C
cov = (volumes[:, None, None] / 20.0) * (
A[:, :, None] * A[:, None, :]
+ B[:, :, None] * B[:, None, :]
+ C[:, :, None] * C[:, None, :]
+ S[:, :, None] * S[:, None, :]
)
Sigma = np.sum(cov, axis=0)
return np.concatenate([[V_ref], com, Sigma.flatten()])
@staticmethod
def compute_mesh_inertia_from_moments(
moments: jtp.Vector, dims: jtp.Vector, density: jtp.Float
) -> tuple[jtp.Float, jtp.Matrix]:
"""
Compute mass and inertia tensor from precomputed volumetric moments.
Uses analytical scaling laws to derive physical properties under
anisotropic scaling without requiring the full mesh geometry.
Under scaling S = diag(sx, sy, sz):
- V' = det(S) * V_ref
- com' = S @ com_ref
- Σ_origin' = det(S) * S @ Σ_ref @ S
Args:
moments: Precomputed moments array of length 13.
dims: Current anisotropic scale factors [sx, sy, sz].
density: Current material density.
Returns:
A tuple of (mass, inertia_at_com).
"""
V_ref = moments[0]
com_ref = moments[1:4]
Sigma_ref = moments[4:13].reshape(3, 3)
det_s = dims[0] * dims[1] * dims[2]
S = jnp.diag(dims)
mass = density * V_ref * det_s
com = dims * com_ref
Sigma_scaled = det_s * (S @ Sigma_ref @ S)
Sigma_com = density * Sigma_scaled - mass * jnp.outer(com, com)
I_com = jnp.trace(Sigma_com) * jnp.eye(3) - Sigma_com
is_valid = V_ref > 1e-12
mass = jnp.where(is_valid, mass, 0.0)
I_com = jnp.where(is_valid, I_com, jnp.zeros((3, 3)))
return mass, I_com
@staticmethod
def compute_mass_and_inertia(
    hw_link_metadata: HwLinkMetadata,
) -> tuple[jtp.Float, jtp.Matrix]:
    """
    Compute the mass and inertia of the hardware links from their metadata.

    For each link, the mass and the inertia tensor are derived from its shape,
    its dimensions, and its density, through the formulas of that shape.
    Links with a negative shape index get zero mass and zero inertia.

    Args:
        hw_link_metadata: Metadata describing the hardware links,
            including their shape, dimensions, and density.

    Returns:
        tuple: A tuple containing:
            - mass: The computed masses of the hardware links.
            - inertia: The computed inertia tensors of the hardware links.
    """

    def box_properties(dims, density, _moments) -> tuple[jtp.Float, jtp.Matrix]:
        lx, ly, lz = dims

        box_mass = density * lx * ly * lz

        box_inertia = jnp.array(
            [
                [box_mass * (ly**2 + lz**2) / 12, 0, 0],
                [0, box_mass * (lx**2 + lz**2) / 12, 0],
                [0, 0, box_mass * (lx**2 + ly**2) / 12],
            ]
        )

        return box_mass, box_inertia

    def cylinder_properties(
        dims, density, _moments
    ) -> tuple[jtp.Float, jtp.Matrix]:
        # The third dimension is not used by cylinders.
        radius, length, _ = dims

        cylinder_mass = density * (jnp.pi * radius**2 * length)

        cylinder_inertia = jnp.array(
            [
                [cylinder_mass * (3 * radius**2 + length**2) / 12, 0, 0],
                [0, cylinder_mass * (3 * radius**2 + length**2) / 12, 0],
                [0, 0, cylinder_mass * (radius**2) / 2],
            ]
        )

        return cylinder_mass, cylinder_inertia

    def sphere_properties(dims, density, _moments) -> tuple[jtp.Float, jtp.Matrix]:
        # Only the first dimension is used by spheres.
        radius = dims[0]

        sphere_mass = density * (4 / 3 * jnp.pi * radius**3)
        sphere_inertia = jnp.eye(3) * (2 / 5 * sphere_mass * radius**2)

        return sphere_mass, sphere_inertia

    def mesh_properties(dims, density, moments) -> tuple[jtp.Float, jtp.Matrix]:
        return HwLinkMetadata.compute_mesh_inertia_from_moments(
            moments, dims, density
        )

    # Branches ordered by shape index: 0 = box, 1 = cylinder, 2 = sphere, 3 = mesh.
    supported_shapes = (
        box_properties,
        cylinder_properties,
        sphere_properties,
        mesh_properties,
    )

    def link_properties(shape_idx, dims, density, moments):

        def zero_properties(_):
            return (
                jnp.asarray(0.0, dtype=density.dtype),
                jnp.zeros((3, 3), dtype=density.dtype),
            )

        def shape_properties(idx):
            return jax.lax.switch(idx, supported_shapes, dims, density, moments)

        # Negative shape indices select the zero mass and inertia.
        return jax.lax.cond(
            shape_idx < 0, zero_properties, shape_properties, shape_idx
        )

    # Process all the links at once.
    return jax.vmap(link_properties)(
        hw_link_metadata.link_shape,
        hw_link_metadata.geometry,
        hw_link_metadata.density,
        hw_link_metadata.mesh_moments,
    )
@staticmethod
def _convert_scaling_to_3d_vector(
    link_shapes: jtp.Int, scaling_factors: jtp.Vector
) -> jtp.Vector:
    """
    Expand the per-shape scaling factors into a 3D scaling vector.

    Args:
        link_shapes:
            The shape type of the link, used as row index of the lookup table
            below (0: box, 1: cylinder, 2: sphere, 3: mesh).
        scaling_factors:
            The scaling factors of the link. Despite the annotation, the only
            thing read from this object is its ``dims`` attribute.

    Returns:
        The entries of ``scaling_factors.dims`` gathered so that they can be
        applied to the x, y, and z components of a position vector.

    Note:
        The entries of ``dims`` are mapped to the 3D vector as follows:

        - Box: [lx, ly, lz]
        - Cylinder: [r, r, l]
        - Sphere: [r, r, r]
        - Mesh: [sx, sy, sz]
    """

    # Row `s` lists which entries of `dims` scale the x, y, and z axes of a
    # link having shape type `s`.
    dims_lookup = jnp.array(
        [
            [0, 1, 2],  # Box: [lx, ly, lz]
            [0, 0, 1],  # Cylinder: [r, r, l]
            [0, 0, 0],  # Sphere: [r, r, r]
            [0, 1, 2],  # Mesh: [sx, sy, sz]
        ]
    )

    # Select the row matching the shape and drop any singleton axis.
    gather_idxs = dims_lookup[link_shapes].squeeze()

    # Gather the scaling factors in the order given by the selected row.
    return scaling_factors.dims[gather_idxs]
@staticmethod
def compute_contact_points(
original_contact_params: jtp.Vector,
link_shapes: jtp.Vector,
original_com_positions: jtp.Vector,
updated_com_positions: jtp.Vector,
scaling_factors: ScalingFactors,
) -> jtp.Matrix:
"""
Compute the new contact points based on the original contact parameters and
the scaling factors.
Args:
original_contact_params: The original contact parameters.
link_shapes: The shape types of the links (e.g., box, sphere, cylinder).
original_com_positions: The original center of mass positions of the links.
updated_com_positions: The updated center of mass positions of the links.
scaling_factors: The scaling factors for the link dimensions.
Returns:
The new contact points positions in the parent link frame.
"""
parent_link_indices = np.array(original_contact_params.body)
# Translate the original contact point positions in the origin, so
# that we can apply the scaling factors.
L_p_Ci = (
original_contact_params.point - original_com_positions[parent_link_indices]
)
# Extract the shape types of the parent links.
parent_link_shapes = jnp.array(link_shapes[parent_link_indices])
def sphere(parent_idx, L_p_C):
r = scaling_factors.dims[parent_idx][0]
return L_p_C * r
def cylinder(parent_idx, L_p_C):
# TODO: Cylinder collisions are not currently supported in JaxSim.
return L_p_C
def box(parent_idx, L_p_C):
lx, ly, lz = scaling_factors.dims[parent_idx]
return jnp.hstack(
[
L_p_C[0] * lx,
L_p_C[1] * ly,
L_p_C[2] * lz,
]
)
def mesh(parent_idx, L_p_C):
sx, sy, sz = scaling_factors.dims[parent_idx]
return jnp.hstack(
[
L_p_C[0] * sx,
L_p_C[1] * sy,
L_p_C[2] * sz,
]
)
new_positions = jax.vmap(
lambda shape_idx, parent_idx, L_p_C: jax.lax.switch(
shape_idx, (box, cylinder, sphere, mesh), parent_idx, L_p_C
)
)(
parent_link_shapes,
parent_link_indices,
L_p_Ci,
)
return new_positions + updated_com_positions[parent_link_indices]
@staticmethod
def compute_inertia_link(I_com, L_H_G) -> jtp.Matrix:
"""
Compute the inertia tensor of the link based on its shape and mass.
"""
L_R_G = L_H_G[:3, :3]
return L_R_G @ I_com @ L_R_G.T
@staticmethod
def apply_scaling(
    has_joints: bool,
    hw_metadata: HwLinkMetadata,
    scaling_factors: ScalingFactors,
) -> HwLinkMetadata:
    """
    Scale the hardware parameters of a link and return the updated metadata.

    Args:
        has_joints:
            Whether the model has at least one joint. When False, the
            ``L_H_pre`` transforms are returned unchanged.
        hw_metadata: The original HwLinkMetadata object.
        scaling_factors: The scaling factors to apply.

    Returns:
        A new HwLinkMetadata object in which ``geometry``, ``density``,
        ``L_H_G``, ``L_H_vis``, and ``L_H_pre`` are replaced by their
        scaled counterparts.
    """

    # 3D vector scaling the x, y, z components of the position vectors.
    scale = HwLinkMetadata._convert_scaling_to_3d_vector(
        hw_metadata.link_shape, scaling_factors
    )

    # =================================
    # Update the kinematics of the link
    # =================================

    # Nominal transform of the G frame and its inverse, used to express
    # the other transforms in the G frame.
    L_H_G = hw_metadata.L_H_G
    G_H_L = jaxsim.math.Transform.inverse(L_H_G)

    # Scale the position of the G frame with respect to the link frame.
    L_H_G_scaled = L_H_G.at[:3, 3].set(scale * L_H_G[:3, 3])

    # Visual frame: express it in G, scale its position, go back to L.
    G_H_vis = G_H_L @ hw_metadata.L_H_vis
    G_H_vis_scaled = G_H_vis.at[:3, 3].set(scale * G_H_vis[:3, 3])
    L_H_vis_scaled = L_H_G_scaled @ G_H_vis_scaled

    if has_joints:
        # Express all the L_H_pre transforms in the G frame.
        G_H_pre = jax.vmap(lambda L_H_pre: G_H_L @ L_H_pre)(hw_metadata.L_H_pre)

        # Scale only the positions of the entries selected by the mask.
        G_p_pre = G_H_pre[:, :3, 3]
        G_H_pre_scaled = G_H_pre.at[:, :3, 3].set(
            jnp.where(
                hw_metadata.L_H_pre_mask[:, None],
                scale[None, :] * G_p_pre,
                G_p_pre,
            )
        )

        # Get back to the (scaled) link frame.
        L_H_pre_scaled = jax.vmap(lambda G_H_pre_i: L_H_G_scaled @ G_H_pre_i)(
            G_H_pre_scaled
        )
    else:
        # Without joints the transforms are passed through untouched.
        L_H_pre_scaled = hw_metadata.L_H_pre

    # ==============================================
    # Update the shape parameters and the density
    # ==============================================

    scaled_geometry = hw_metadata.geometry * scaling_factors.dims
    scaled_density = hw_metadata.density * scaling_factors.density

    # =============================
    # Return updated HwLinkMetadata
    # =============================

    return hw_metadata.replace(
        geometry=scaled_geometry,
        density=scaled_density,
        L_H_G=L_H_G_scaled,
        L_H_vis=L_H_vis_scaled,
        L_H_pre=L_H_pre_scaled,
    )
@jax_dataclasses.pytree_dataclass
class ScalingFactors(JaxsimDataclass):
    """
    Multiplicative factors used to rescale the hardware parameters of links.

    Attributes:
        dims: The factors multiplying the dimensions of the link shapes.
        density: The factor multiplying the density of the links.
    """

    # Factors applied to the shape dimensions (multiplied with the geometry).
    dims: jtp.Vector

    # Factor applied to the density.
    density: jtp.Float
@dataclasses.dataclass(frozen=True)
class ConstraintType:
    """
    Integer identifiers of the supported kinematic constraint types.

    The identifiers are class-level constants, therefore the dataclass has no
    instance fields.
    """

    # Identifier of the weld constraint.
    Weld: ClassVar[int] = 0

    # TODO: handle Connect constraint
    # Connect: ClassVar[int] = 1
@jax_dataclasses.pytree_dataclass
class ConstraintMap(JaxsimDataclass):
    """
    Container of the kinematic constraints of a model.

    All the attributes are parallel arrays storing one entry per constraint,
    and they are empty by default.
    """

    # Index of the first frame of each constraint pair.
    frame_idxs_1: jtp.Int = dataclasses.field(
        default_factory=lambda: jnp.array([], dtype=int)
    )

    # Index of the second frame of each constraint pair.
    frame_idxs_2: jtp.Int = dataclasses.field(
        default_factory=lambda: jnp.array([], dtype=int)
    )

    # Type of each constraint.
    constraint_types: jtp.Int = dataclasses.field(
        default_factory=lambda: jnp.array([], dtype=int)
    )

    # Proportional gain of the Baumgarte stabilization of each constraint.
    K_P: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array([], dtype=float)
    )

    # Derivative gain of the Baumgarte stabilization of each constraint.
    K_D: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array([], dtype=float)
    )

    # Precomputed parent link indices for each constraint pair
    parent_link_idxs_1: jtp.Int = dataclasses.field(
        default_factory=lambda: jnp.array([], dtype=int)
    )
    parent_link_idxs_2: jtp.Int = dataclasses.field(
        default_factory=lambda: jnp.array([], dtype=int)
    )

    def add_constraint(
        self,
        model: jaxsim.api.model.JaxSimModel,
        frame_idx_1: int,
        frame_idx_2: int,
        constraint_type: int,
        K_P: float | None = None,
        K_D: float | None = None,
    ) -> ConstraintMap:
        """
        Return a copy of the constraint map extended with a new constraint.

        Args:
            model: The model for which the constraints are added.
            frame_idx_1: The index of the first frame.
            frame_idx_2: The index of the second frame.
            constraint_type: The type of constraint.
            K_P: The proportional gain for Baumgarte stabilization (default: 1000).
            K_D: The derivative gain for Baumgarte stabilization (default: 2 * sqrt(K_P)).

        Returns:
            A new ConstraintMap instance with the added constraint.

        Note:
            Since this method returns a new instance of ConstraintMap with the new constraint,
            it will trigger recompilations in JIT-compiled functions.
        """

        # Fall back to the default Baumgarte coefficients when not provided.
        if K_P is None:
            K_P = jnp.array([1000.0])

        if K_D is None:
            K_D = 2 * jnp.sqrt(K_P)

        # Extend the per-constraint arrays with the new entries.
        extended = dict(
            frame_idxs_1=jnp.append(self.frame_idxs_1, frame_idx_1),
            frame_idxs_2=jnp.append(self.frame_idxs_2, frame_idx_2),
            constraint_types=jnp.append(self.constraint_types, constraint_type),
            K_P=jnp.append(self.K_P, K_P),
            K_D=jnp.append(self.K_D, K_D),
        )

        # Look up the parent links of the two frames using the model.
        parent_link_idx_1, parent_link_idx_2 = (
            jaxsim.api.frame.idx_of_parent_link(model, frame_index=frame_idx)
            for frame_idx in (frame_idx_1, frame_idx_2)
        )

        extended["parent_link_idxs_1"] = jnp.append(
            self.parent_link_idxs_1, parent_link_idx_1
        )
        extended["parent_link_idxs_2"] = jnp.append(
            self.parent_link_idxs_2, parent_link_idx_2
        )

        # Build the new object from the extended arrays.
        return ConstraintMap(**extended)
================================================
FILE: src/jaxsim/api/link.py
================================================
import functools
from collections.abc import Sequence
import jax
import jax.numpy as jnp
import jax.scipy.linalg
import numpy as np
import jaxsim.api as js
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim import exceptions
from jaxsim.math import Adjoint
from .common import VelRepr
# =======================
# Index-related functions
# =======================
@functools.partial(jax.jit, static_argnames="link_name")
def name_to_idx(model: js.model.JaxSimModel, *, link_name: str) -> jtp.Int:
    """
    Look up the index of a link given its name.

    Args:
        model: The model to consider.
        link_name: The name of the link.

    Returns:
        The index of the link, as a scalar integer array.

    Raises:
        ValueError: If the model has no link with the given name.
    """

    # The name is a static argument, hence the lookup is plain Python.
    if link_name in model.link_names():
        position = model.kin_dyn_parameters.link_names.index(link_name)
        return jnp.array(position).astype(int).squeeze()

    raise ValueError(f"Link '{link_name}' not found in the model.")
def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
    """
    Look up the name of a link given its index.

    Args:
        model: The model to consider.
        link_index: The index of the link.

    Returns:
        The name of the link.
    """

    # Only negative indices are rejected explicitly here.
    is_negative = link_index < 0

    exceptions.raise_value_error_if(
        condition=is_negative,
        msg="Invalid link index '{idx}'",
        idx=link_index,
    )

    link_names = model.kin_dyn_parameters.link_names
    return link_names[link_index]
@functools.partial(jax.jit, static_argnames="link_names")
def names_to_idxs(
    model: js.model.JaxSimModel, *, link_names: Sequence[str]
) -> jax.Array:
    """
    Look up the indices of multiple links given their names.

    Args:
        model: The model to consider.
        link_names: The names of the links.

    Returns:
        The indices of the links, in the same order as the given names.
    """

    # Resolve each name independently with the single-link lookup.
    link_idxs = [
        name_to_idx(model=model, link_name=link_name) for link_name in link_names
    ]

    return jnp.array(link_idxs).astype(int)
def idxs_to_names(
    model: js.model.JaxSimModel, *, link_indices: Sequence[jtp.IntLike] | jtp.VectorLike
) -> tuple[str, ...]:
    """
    Convert a sequence of link indices to their corresponding names.

    Args:
        model: The model to consider.
        link_indices: The indices of the links.

    Returns:
        The names of the links, in the same order as the given indices.

    Raises:
        IndexError: If any index is out of range.
    """

    link_names = model.kin_dyn_parameters.link_names

    # Index the sequence of names directly instead of going through a NumPy
    # array of strings: this returns builtin `str` objects, as declared by the
    # return type, and avoids converting all the names at every call.
    return tuple(link_names[link_index] for link_index in link_indices)
# =========
# Link APIs
# =========
@jax.jit
def mass(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float:
    """
    Get the mass of a link.

    Args:
        model: The model to consider.
        link_index: The index of the link.

    Returns:
        The mass of the link.
    """

    # Reject indices outside of the range [0, number_of_links).
    below_range = link_index < 0
    above_range = link_index >= model.number_of_links()

    exceptions.raise_value_error_if(
        condition=jnp.array([below_range, above_range]).any(),
        msg="Invalid link index '{idx}'",
        idx=link_index,
    )

    link_masses = model.kin_dyn_parameters.link_parameters.mass
    return link_masses[link_index].astype(float)
@jax.jit
def spatial_inertia(
    model: js.model.JaxSimModel, *, link_index: jtp.IntLike
) -> jtp.Matrix:
    r"""
    Compute the 6D spatial inertia of the link.

    Args:
        model: The model to consider.
        link_index: The index of the link.

    Returns:
        The :math:`6 \times 6` matrix representing the spatial inertia of the link expressed in
        the link frame (body-fixed representation).
    """

    # Reject indices outside of the range [0, number_of_links).
    below_range = link_index < 0
    above_range = link_index >= model.number_of_links()

    exceptions.raise_value_error_if(
        condition=jnp.array([below_range, above_range]).any(),
        msg="Invalid link index '{idx}'",
        idx=link_index,
    )

    def select_link(leaf):
        # Extract the entry of the considered link from a stacked parameter.
        return leaf[link_index]

    # Slice all the leaves of the link parameters at the considered link.
    single_link_parameters = jax.tree.map(
        select_link, model.kin_dyn_parameters.link_parameters
    )

    return js.kin_dyn_parameters.LinkParameters.spatial_inertia(
        single_link_parameters
    )
@jax.jit
def transform(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_index: jtp.IntLike,
) -> jtp.Matrix:
    """
    Get the SE(3) transform from the world frame to the link frame.

    The transform is read from the link transforms stored in `data`.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.

    Returns:
        The 4x4 matrix representing the transform.
    """

    # Reject indices outside of the range [0, number_of_links).
    below_range = link_index < 0
    above_range = link_index >= model.number_of_links()

    exceptions.raise_value_error_if(
        condition=jnp.array([below_range, above_range]).any(),
        msg="Invalid link index '{idx}'",
        idx=link_index,
    )

    link_transforms = data._link_transforms
    return link_transforms[link_index]
@jax.jit
def com_position(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_index: jtp.IntLike,
    in_link_frame: jtp.BoolLike = True,
) -> jtp.Vector:
    """
    Compute the position of the center of mass of the link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.
        in_link_frame:
            Whether to return the position in the link frame or in the world frame.

    Returns:
        The 3D position of the center of mass of the link.
    """

    from jaxsim.math.inertia import Inertia

    # Extract the CoM position in the link frame from the spatial inertia.
    M_L = spatial_inertia(model=model, link_index=link_index)
    _, L_p_CoM, _ = Inertia.to_params(M=M_L)
    L_p_CoM = L_p_CoM.squeeze()

    # Express the same point in the world frame using homogeneous coordinates.
    W_H_L = transform(link_index=link_index, model=model, data=data)
    W_p_CoM = (W_H_L @ jnp.hstack([L_p_CoM, 1]))[0:3].squeeze()

    # Both candidates are computed, the flag only selects which one is returned.
    return jax.lax.select(
        pred=in_link_frame,
        on_true=L_p_CoM,
        on_false=W_p_CoM,
    )
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
link_index: jtp.IntLike,
output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
r"""
Compute the free-floating jacobian of the link.
Args:
model: The model to consider.
data: The data of the considered model.
link_index: The index of the link.
output_vel_repr:
The output velocity representation of the free-floating jacobian.
Returns:
The :math:`6 \times (6+n)` free-floating jacobian of the link.
Note:
The input representation of the free-floating jacobian is the active
velocity representation.
"""
exceptions.raise_value_error_if(
condition=jnp.array(
[link_index < 0, link_index >= model.number_of_links()]
).any(),
msg="Invalid link index '{idx}'",
idx=link_index,
)
output_vel_repr = (
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)
# Compute the doubly-left free-floating full jacobian.
B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left(
model=model,
joint_positions=data.joint_positions,
)
# Compute the actual doubly-left free-floating jacobian of the link.
κb = model.kin_dyn_parameters.support_body_array_bool[link_index]
B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WX_B
# Adjust the input representation such that `J_WL_I @ I_ν`.
match data.velocity_representation:
case VelRepr.Inertial:
W_H_B = data._base_transform
B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841
B_X_W, jnp.eye(model.dofs())
)
case VelRepr.Body:
B_J_WL_I = B_J_WL_B
case VelRepr.Mixed:
W_R_B = jaxsim.math.Quaternion.to_dcm(data.base_orientation)
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841
B_X_BW, jnp.eye(model.dofs())
)
case _:
raise ValueError(data.velocity_representation)
B_H_L = B_H_Li[link_index]
# Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
match output_vel_repr:
case VelRepr.Inertial:
W_H_B = data._base_transform
W_X_B = Adjoint.from_transform(transform=W_H_B)
O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I # noqa: F841
case VelRepr.Body:
L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True)
L_J_WL_I = L_X_B @ B_J_WL_I
O_J_WL_I = L_J_WL_I
case VelRepr.Mixed:
W_H_B = data._base_transform
W_H_L = W_H_B @ B_H_L
LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
LW_X_B = Adjoint.from_transform(transform=LW_H_B)
LW_J_WL_I = LW_X_B @ B_J_WL_I
O_J_WL_I = LW_J_WL_I
case _:
raise ValueError(output_vel_repr)
return O_J_WL_I
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def velocity(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_index: jtp.IntLike,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Vector:
    """
    Compute the 6D velocity of the link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.
        output_vel_repr:
            The output velocity representation of the link velocity.
            If None, the velocity representation of `data` is used.

    Returns:
        The 6D velocity of the link in the specified velocity representation.
    """

    # Reject indices outside of the range [0, number_of_links).
    below_range = link_index < 0
    above_range = link_index >= model.number_of_links()

    exceptions.raise_value_error_if(
        condition=jnp.array([below_range, above_range]).any(),
        msg="Invalid link index '{idx}'",
        idx=link_index,
    )

    if output_vel_repr is None:
        output_vel_repr = data.velocity_representation

    # Link jacobian having as input representation the one of `data`, and as
    # output representation the one requested by the user (or taken from data).
    link_jacobian = jacobian(
        model=model,
        data=data,
        link_index=link_index,
        output_vel_repr=output_vel_repr,
    )

    # Map the generalized velocity, expressed in the input velocity
    # representation, to the link velocity in the output representation.
    generalized_velocity = data.generalized_velocity
    return link_jacobian @ generalized_velocity
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian_derivative(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_index: jtp.IntLike,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    r"""
    Compute the derivative of the free-floating jacobian of the link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian derivative.
            If None, the velocity representation of `data` is used.

    Returns:
        The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the link.

    Note:
        The input representation of the free-floating jacobian derivative is the active
        velocity representation.
    """

    # Reject indices outside of the range [0, number_of_links).
    below_range = link_index < 0
    above_range = link_index >= model.number_of_links()

    exceptions.raise_value_error_if(
        condition=jnp.array([below_range, above_range]).any(),
        msg="Invalid link index '{idx}'",
        idx=link_index,
    )

    if output_vel_repr is None:
        output_vel_repr = data.velocity_representation

    # Compute the jacobian derivatives with the model-level function and
    # keep only the entry of the considered link.
    all_jacobian_derivatives = (
        js.model.generalized_free_floating_jacobian_derivative(
            model=model, data=data, output_vel_repr=output_vel_repr
        )
    )

    return all_jacobian_derivatives[link_index]
@jax.jit
def bias_acceleration(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_index: jtp.IntLike,
) -> jtp.Vector:
    """
    Compute the bias acceleration of the link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.

    Returns:
        The 6D bias acceleration of the link.
    """

    # Reject indices outside of the range [0, number_of_links).
    below_range = link_index < 0
    above_range = link_index >= model.number_of_links()

    exceptions.raise_value_error_if(
        condition=jnp.array([below_range, above_range]).any(),
        msg="Invalid link index '{idx}'",
        idx=link_index,
    )

    # Compute the bias acceleration of all links in the active representation
    # and keep only the entry of the considered link.
    all_bias_accelerations = js.model.link_bias_accelerations(model=model, data=data)

    return all_bias_accelerations[link_index]
================================================
FILE: src/jaxsim/api/model.py
================================================
from __future__ import annotations
import copy
import dataclasses
import enum
import functools
import pathlib
from collections.abc import Sequence
import jax
import jax.numpy as jnp
import jax_dataclasses
import numpy as np
import rod
from jax_dataclasses import Static
from rod.urdf.exporter import UrdfExporter
import jaxsim.api as js
import jaxsim.terrain
import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.api.kin_dyn_parameters import (
HwLinkMetadata,
KinDynParameters,
LinkParameters,
LinkParametrizableShape,
ScalingFactors,
)
from jaxsim.math import Adjoint, Cross, Skew
from jaxsim.parsers.descriptions import ModelDescription
from jaxsim.parsers.descriptions.joint import JointDescription
from jaxsim.parsers.descriptions.link import LinkDescription
from jaxsim.parsers.rod.utils import prepare_mesh_for_parametrization
from jaxsim.utils import JaxsimDataclass, Mutability, wrappers
from .common import VelRepr
class IntegratorType(enum.IntEnum):
    """
    Enumeration of the integrators available for the simulation.

    The integer values are assigned automatically in declaration order.
    """

    # Semi-implicit Euler scheme.
    SemiImplicitEuler = enum.auto()

    # Runge-Kutta 4 scheme.
    RungeKutta4 = enum.auto()

    # Runge-Kutta 4 scheme, fast variant.
    RungeKutta4Fast = enum.auto()
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class JaxSimModel(JaxsimDataclass):
"""
The JaxSim model defining the kinematics and dynamics of a robot.
"""
model_name: Static[str]
time_step: float = dataclasses.field(
default=0.001,
)
terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
)
gravity: Static[float] = -jaxsim.math.STANDARD_GRAVITY
contact_model: Static[jaxsim.rbda.contacts.ContactModel | None] = dataclasses.field(
default=None, repr=False
)
contact_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(
default=None, repr=False
)
actuation_params: Static[jaxsim.rbda.actuation.ActuationParams] = dataclasses.field(
default=None, repr=False
)
kin_dyn_parameters: js.kin_dyn_parameters.KinDynParameters | None = (
dataclasses.field(default=None, repr=False)
)
integrator: Static[IntegratorType] = dataclasses.field(
default=IntegratorType.SemiImplicitEuler, repr=False
)
built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
default=None, repr=False
)
_description: Static[wrappers.HashlessObject[ModelDescription | None]] = (
dataclasses.field(default=None, repr=False)
)
@property
def description(self) -> ModelDescription:
    """
    Get the model description stored in the model.

    Returns:
        The object held by the wrapper stored in the `_description` attribute.
    """

    wrapped_description = self._description
    return wrapped_description.get()
def __eq__(self, other: JaxSimModel) -> bool:
    """
    Check whether two models are equal.

    Two models are considered equal when they have the same name, time step,
    kinematic and dynamic parameters, and contact model. These are the same
    attributes used by `__hash__`, so that models comparing equal are
    guaranteed to have the same hash.

    Args:
        other: The object to compare with.

    Returns:
        True if `other` is a JaxSimModel matching this one, False otherwise.
    """

    if not isinstance(other, JaxSimModel):
        return False

    if self.model_name != other.model_name:
        return False

    if self.time_step != other.time_step:
        return False

    if self.kin_dyn_parameters != other.kin_dyn_parameters:
        return False

    # The contact model contributes to the hash, therefore it must also
    # contribute to the equality to keep the two methods consistent.
    if self.contact_model != other.contact_model:
        return False

    return True
def __hash__(self) -> int:
    """
    Compute the hash of the model.

    Returns:
        The hash of the tuple collecting the hashes of the model name, the
        time step, the kinematic and dynamic parameters, and the contact model.
    """

    hashed_attributes = (
        self.model_name,
        self.time_step,
        self.kin_dyn_parameters,
        self.contact_model,
    )

    # Hash each attribute first, then hash the resulting tuple of integers.
    return hash(tuple(hash(attribute) for attribute in hashed_attributes))
# ========================
# Initialization and state
# ========================
@classmethod
def build_from_model_description(
    cls,
    model_description: str | pathlib.Path | rod.Model,
    *,
    model_name: str | None = None,
    time_step: jtp.FloatLike | None = None,
    terrain: jaxsim.terrain.Terrain | None = None,
    contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
    contact_params: jaxsim.rbda.contacts.ContactsParams | None = None,
    actuation_params: jaxsim.rbda.actuation.ActuationParams | None = None,
    integrator: IntegratorType | None = None,
    is_urdf: bool | None = None,
    considered_joints: Sequence[str] | None = None,
    gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
    constraints: jaxsim.rbda.kinematic_constraints.ConstraintMap | None = None,
    parametrized_links: tuple[str, ...] | None = None,
) -> JaxSimModel:
    """
    Build a Model object from a model description.

    Args:
        model_description:
            A path to an SDF/URDF file, a string containing
            its content, or a pre-parsed/pre-built rod model.
        model_name:
            The name of the model. If not specified, it is read from the description.
        time_step:
            The default time step to consider for the simulation. It can be
            manually overridden in the function that steps the simulation.
        terrain: The terrain to consider (the default is a flat infinite plane).
        contact_model:
            The contact model to consider.
            If not specified, a soft contacts model is used.
        contact_params: The parameters of the contact model.
        actuation_params: The parameters of the actuation model.
        integrator: The integrator to use for the simulation.
        is_urdf:
            The optional flag to force the model description to be parsed as a URDF.
            This is usually automatically inferred.
        considered_joints:
            The list of joints to consider. If None, all joints are considered.
        gravity:
            The gravity constant. Normally passed as a positive value, since
            it is negated before being stored in the model.
        constraints:
            An object of type ConstraintMap containing the kinematic constraints
            to consider. If None, no constraints are considered.
            Note that constraints can be used only with RelaxedRigidContacts.
        parametrized_links:
            The optional list of links to be parametrized. If None, all links
            are parametrized.

    Returns:
        The built Model object.
    """

    import jaxsim.parsers.rod

    # Parse the input resource (either a path to file or a string with the URDF/SDF)
    # and build the -intermediate- model description.
    description = jaxsim.parsers.rod.build_model_description(
        model_description=model_description, is_urdf=is_urdf
    )

    # Lump links together if not all joints are considered.
    # Note: this procedure assigns a zero position to all joints not considered.
    if considered_joints is not None:
        description = description.reduce(considered_joints=considered_joints)

    # Build the model, flipping the sign of the gravity constant.
    model = cls.build(
        model_description=description,
        model_name=model_name,
        time_step=time_step,
        terrain=terrain,
        contact_model=contact_model,
        contact_params=contact_params,
        actuation_params=actuation_params,
        integrator=integrator,
        gravity=-gravity,
        constraints=constraints,
        parametrized_links=parametrized_links,
    )

    # Store the origin of the model, in case downstream logic needs it.
    with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
        model.built_from = model_description

    # Compute the hw parametrization metadata of the model
    # TODO: move the building of the metadata to KinDynParameters.build()
    # and use the model_description instead of model.built_from.
    with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
        hw_link_metadata = model.compute_hw_link_metadata(
            parametrized_links=parametrized_links
        )
        model.kin_dyn_parameters.hw_link_metadata = hw_link_metadata

    return model
@classmethod
def build(
    cls,
    model_description: ModelDescription,
    *,
    model_name: str | None = None,
    time_step: jtp.FloatLike | None = None,
    terrain: jaxsim.terrain.Terrain | None = None,
    contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
    contact_params: jaxsim.rbda.contacts.ContactsParams | None = None,
    actuation_params: jaxsim.rbda.actuation.ActuationParams | None = None,
    integrator: IntegratorType | None = None,
    gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
    constraints: jaxsim.rbda.kinematic_constraints.ConstraintMap | None = None,
    parametrized_links: tuple[str, ...] | None = None,
) -> JaxSimModel:
    """
    Build a Model object from an intermediate model description.

    Args:
        model_description:
            The intermediate model description defining the kinematics and dynamics
            of the model.
        model_name:
            The name of the model. If not specified, it is read from the description.
        time_step:
            The default time step to consider for the simulation. It can be
            manually overridden in the function that steps the simulation.
        terrain: The terrain to consider (the default is a flat infinite plane).
        contact_model:
            The contact model to consider.
            If not specified, a soft contact model is used.
        contact_params:
            The parameters of the contact model. If not specified, they are
            created from the parameters class of the contact model.
        actuation_params: The parameters of the actuation model.
        integrator: The integrator to use for the simulation.
        gravity:
            The gravity constant. It is stored in the model as it is, with
            no sign change.
        constraints:
            An object of type ConstraintMap containing the kinematic constraints
            to consider. If None, no constraints are considered.
        parametrized_links:
            The optional list of links to be parametrized. If None, all links
            are parametrized. This argument is currently not used by this method.

    Returns:
        The built Model object.
    """

    # Defaults of the optional arguments are taken from the dataclass fields.
    model_fields = JaxSimModel.__dataclass_fields__

    # Set the model name (if not provided, use the one from the model description).
    if model_name is None:
        model_name = model_description.name

    # Consider the default terrain (a flat infinite plane) if not specified.
    if terrain is None:
        terrain = model_fields["terrain"].default_factory()

    # Consider the default time step if not specified.
    if time_step is None:
        time_step = model_fields["time_step"].default

    # Create the default contact model.
    # It will be populated with an initial estimation of good parameters.
    # While these might not be the best, they are a good starting point.
    if contact_model is None:
        contact_model = jaxsim.rbda.contacts.SoftContacts.build()

    if contact_params is None:
        contact_params = contact_model._parameters_class()

    if actuation_params is None:
        actuation_params = jaxsim.rbda.actuation.ActuationParams()

    # Consider the default integrator if not specified.
    if integrator is None:
        integrator = model_fields["integrator"].default

    # Build the parameters defining the kinematics and dynamics of the model.
    kin_dyn_parameters = js.kin_dyn_parameters.KinDynParameters.build(
        model_description=model_description, constraints=constraints
    )

    # The description is wrapped as hashless since it's a static argument, and we
    # don't want to trigger recompilation if it changes. All relevant parameters
    # needed to compute kinematics and dynamics quantities are stored in the
    # kin_dyn_parameters attribute.
    hashless_description = wrappers.HashlessObject(obj=model_description)

    # Build the model.
    return cls(
        model_name=model_name,
        kin_dyn_parameters=kin_dyn_parameters,
        time_step=time_step,
        terrain=terrain,
        contact_model=contact_model,
        contact_params=contact_params,
        actuation_params=actuation_params,
        integrator=integrator,
        gravity=gravity,
        _description=hashless_description,
    )
def compute_hw_link_metadata(
self, parametrized_links: tuple[str, ...] | None = None
) -> HwLinkMetadata:
"""
Compute the parametric metadata of the links in the model.
Args:
parametrized_links:
An optional tuple of link names to be parametrized. If None,
all links will be parametrized.
Returns:
An instance of HwLinkMetadata containing the metadata of all links.
"""
model_description = self.description
# Get ordered links and joints from the model description
ordered_links: list[LinkDescription] = sorted(
list(model_description.links_dict.values()),
key=lambda l: l.index,
)
ordered_joints: list[JointDescription] = sorted(
list(model_description.joints_dict.values()),
key=lambda j: j.index,
)
# Ensure the model was built from a valid source
rod_model = None
match self.built_from:
case str() | pathlib.Path():
rod_model = rod.Sdf.load(sdf=self.built_from).models()[0]
assert rod_model.name == self.name()
case rod.Model():
rod_model = self.built_from
case _:
logging.debug(
f"Invalid type for model.built_from ({type(self.built_from)})."
"Skipping for hardware parametrization."
)
return HwLinkMetadata.empty()
# Use URDF frame convention for consistent pose representation
rod_model.switch_frame_convention(
frame_convention=rod.FrameConvention.Urdf, explicit_frames=True
)
rod_links_dict = {link.name: link for link in rod_model.links()}
# Initialize lists to collect metadata for all links
shapes = []
geoms = []
densities = []
L_H_Gs = []
L_H_vises = []
L_H_pre_masks = []
L_H_pre = []
mesh_moments_list = []
mesh_vertices = []
mesh_faces = []
mesh_offsets = []
mesh_uris = []
# Process each link, only parametrizing those in parametrized_links if provided
for link_description in ordered_links:
link_name = link_description.name
if parametrized_links is not None and link_name not in parametrized_links:
# Mark as unsupported for non-parametrized links
shapes.append(LinkParametrizableShape.Unsupported)
geoms.append([0, 0, 0])
densities.append(0.0)
L_H_Gs.append(jnp.eye(4))
L_H_vises.append(jnp.eye(4))
L_H_pre_masks.append([0] * self.number_of_joints())
L_H_pre.append([jnp.eye(4)] * self.number_of_joints())
mesh_vertices.append(None)
mesh_faces.append(None)
mesh_offsets.append(None)
mesh_uris.append(None)
mesh_moments_list.append(np.zeros(13))
continue
rod_link = rod_links_dict.get(link_name)
link_index = int(js.link.name_to_idx(model=self, link_name=link_name))
# Get child joints for the link
child_joints_indices = [
js.joint.name_to_idx(model=self, joint_name=j.name)
for j in ordered_joints
if j.parent.name == link_name
]
# Skip unsupported links
if not jnp.allclose(
self.kin_dyn_parameters.joint_model.suc_H_i[link_index],
jnp.eye(4),
**(dict(atol=1e-6) if not jax.config.jax_enable_x64 else {}),
):
logging.debug(
f"Skipping link '{link_name}' for hardware parametrization due to unsupported suc_H_link."
)
rod_link = None
# Compute density and dimensions
mass = float(self.kin_dyn_parameters.link_parameters.mass[link_index])
# Find the first supported visual
supported_visual = (
next(
(
v
for v in rod_link.visuals()
if isinstance(
v.geometry.geometry(),
(rod.Box, rod.Sphere, rod.Cylinder, rod.Mesh),
)
),
None,
)
if rod_link
else None
)
geometry = (
supported_visual.geometry.geometry() if supported_visual else None
)
if isinstance(geometry, rod.Box):
lx, ly, lz = geometry.size
density = mass / (lx * ly * lz)
geom = [lx, ly, lz]
shape = LinkParametrizableShape.Box
mesh_vertices.append(None)
mesh_faces.append(None)
mesh_offsets.append(None)
mesh_uris.append(None)
mesh_moments_list.append(np.zeros(13))
elif isinstance(geometry, rod.Sphere):
r = geometry.radius
density = mass / (4 / 3 * jnp.pi * r**3)
geom = [r, 0, 0]
shape = LinkParametrizableShape.Sphere
mesh_vertices.append(None)
mesh_faces.append(None)
mesh_offsets.append(None)
mesh_uris.append(None)
mesh_moments_list.append(np.zeros(13))
elif isinstance(geometry, rod.Cylinder):
r, l = geometry.radius, geometry.length
density = mass / (jnp.pi * r**2 * l)
geom = [r, l, 0]
shape = LinkParametrizableShape.Cylinder
mesh_vertices.append(None)
mesh_faces.append(None)
mesh_offsets.append(None)
mesh_uris.append(None)
mesh_moments_list.append(np.zeros(13))
elif isinstance(geometry, rod.Mesh):
# Load and prepare mesh for parametric scaling
try:
mesh_data = prepare_mesh_for_parametrization(
mesh_uri=geometry.uri,
scale=geometry.scale,
)
density = (
mass / mesh_data["volume"] if mesh_data["volume"] > 0 else 0.0
)
# For meshes, store cumulative scale factors (initially 1.0) in geometry
# instead of bounding box extents. This allows proper multiplicative scaling.
geom = [1.0, 1.0, 1.0]
shape = LinkParametrizableShape.Mesh
# Store mesh data
mesh_vertices.append(mesh_data["vertices"])
mesh_faces.append(mesh_data["faces"])
mesh_offsets.append(mesh_data["offset"])
mesh_uris.append(mesh_data["uri"])
# Precompute volumetric moments for JIT-friendly inertia computation
mesh_moments_list.append(
HwLinkMetadata.precompute_mesh_moments(
mesh_data["vertices"], mesh_data["faces"]
)
)
logging.info(
f"Loaded mesh for link '{link_name}': "
f"{len(mesh_data['vertices'])} vertices, "
f"{len(mesh_data['faces'])} faces, "
)
except Exception as e:
logging.warning(
f"Failed to load mesh for link '{link_name}': {e}. "
f"Marking as unsupported."
)
density = 0.0
geom = [0, 0, 0]
shape = LinkParametrizableShape.Unsupported
mesh_vertices.append(None)
mesh_faces.append(None)
mesh_offsets.append(None)
mesh_uris.append(None)
mesh_moments_list.append(np.zeros(13))
else:
logging.debug(
f"Skipping link '{link_name}' for hardware parametrization due to unsupported geometry."
)
density = 0.0
geom = [0, 0, 0]
shape = LinkParametrizableShape.Unsupported
mesh_vertices.append(None)
mesh_faces.append(None)
mesh_offsets.append(None)
mesh_uris.append(None)
mesh_moments_list.append(np.zeros(13))
inertial_pose = (
rod_link.inertial.pose.transform() if rod_link else jnp.eye(4)
)
visual_pose = (
supported_visual.pose.transform() if supported_visual else jnp.eye(4)
)
l_h_pre_mask = [
int(joint_index in child_joints_indices)
for joint_index in range(self.number_of_joints())
]
l_h_pre = [
(
self.kin_dyn_parameters.joint_model.λ_H_pre[joint_index + 1]
if joint_index in child_joints_indices
else jnp.eye(4)
)
for joint_index in range(self.number_of_joints())
]
shapes.append(shape)
geoms.append(geom)
densities.append(density)
L_H_Gs.append(inertial_pose)
L_H_vises.append(visual_pose)
L_H_pre_masks.append(l_h_pre_mask)
L_H_pre.append(l_h_pre)
if np.all(np.array(shapes) == LinkParametrizableShape.Unsupported):
logging.debug(
"All links were skipped for hardware parametrization. Returning empty metadata."
)
return HwLinkMetadata.empty()
# Stack collected data into JAX arrays
# Handle L_H_pre specially: ensure shape (n_links, n_joints, 4, 4) even when n_joints=0
L_H_pre_array = jnp.array(L_H_pre, dtype=float)
if self.number_of_joints() == 0:
# Reshape from (n_links, 0) to (n_links, 0, 4, 4)
n_links = len(L_H_pre)
L_H_pre_array = L_H_pre_array.reshape(n_links, 0, 4, 4)
return HwLinkMetadata(
link_shape=jnp.array(shapes, dtype=int),
geometry=jnp.array(geoms, dtype=float),
density=jnp.array(densities, dtype=float),
L_H_G=jnp.array(L_H_Gs, dtype=float),
L_H_vis=jnp.array(L_H_vises, dtype=float),
L_H_pre_mask=jnp.array(L_H_pre_masks, dtype=bool),
L_H_pre=L_H_pre_array,
mesh_moments=jnp.array(np.stack(mesh_moments_list), dtype=float),
mesh_vertices=(
tuple(
wrappers.HashedNumpyArray(array=v) if v is not None else None
for v in mesh_vertices
)
if any(v is not None for v in mesh_vertices)
else None
),
mesh_faces=(
tuple(
wrappers.HashedNumpyArray(array=f) if f is not None else None
for f in mesh_faces
)
if any(f is not None for f in mesh_faces)
else None
),
mesh_offset=(
tuple(
wrappers.HashedNumpyArray(array=o) if o is not None else None
for o in mesh_offsets
)
if any(o is not None for o in mesh_offsets)
else None
),
mesh_uri=(
tuple(mesh_uris) if any(u is not None for u in mesh_uris) else None
),
)
def export_updated_model(self) -> str:
    """
    Export the JaxSim model to URDF with the current hardware parameters.

    A ROD model is obtained from `built_from` (deep-copied if it is already a
    ROD model, loaded from SDF if it is a string or a path). Its links and joints
    are then updated with the current link parameters and hardware metadata, and
    the result is serialized to a URDF string.

    Returns:
        The URDF string of the updated model.

    Raises:
        RuntimeError: If called while JAX is tracing (e.g. in a JIT-compiled function).
        ValueError: If `built_from` is neither a ROD model nor a string/path.

    Note:
        This method is not meant to be used in JIT-compiled functions.
    """
    # Under tracing, newly created arrays are tracers: use this to detect JIT.
    if isinstance(jnp.zeros(0), jax.core.Tracer):
        raise RuntimeError("This method cannot be used in JIT-compiled functions")

    # Ensure `built_from` is a ROD model and create `rod_model_output`
    if isinstance(self.built_from, rod.Model):
        # Work on a copy so that the original ROD model is not modified.
        rod_model_output = copy.deepcopy(self.built_from)
    elif isinstance(self.built_from, (str, pathlib.Path)):
        rod_model_output = rod.Sdf.load(sdf=self.built_from).models()[0]
    else:
        raise ValueError(
            "The JaxSim model must be built from a valid ROD model source"
        )

    # Switch to URDF frame convention for easier mapping
    rod_model_output.switch_frame_convention(
        frame_convention=rod.FrameConvention.Urdf,
        explicit_frames=True,
        attach_frames_to_links=True,
    )

    # Get links and joints from the ROD model
    links_dict = {link.name: link for link in rod_model_output.links()}
    joints_dict = {joint.name: joint for joint in rod_model_output.joints()}

    # Iterate over the hardware metadata to update the ROD model
    hw_metadata = self.kin_dyn_parameters.hw_link_metadata

    # Names of the links and joints that exist in this (possibly reduced) model.
    reduced_link_names = set(self.link_names())
    reduced_joint_names = set(self.joint_names())

    unit_scale = np.ones(3, dtype=float)
    # Per-link scale vector (x, y, z), filled while processing the links below
    # and later propagated through fixed joints.
    link_scale_factors: dict[str, np.ndarray] = {}

    # Gather the visual and collision elements of a link into a flat list,
    # accepting both single elements and lists/tuples, and dropping None entries.
    def collect_link_elements(link) -> list:
        elements_to_update_raw = (link.visual, link.collision)
        elements_to_update = []
        for entry in elements_to_update_raw:
            if entry is None:
                continue
            if isinstance(entry, (list, tuple)):
                elements_to_update.extend(e for e in entry if e is not None)
            else:
                elements_to_update.append(entry)
        return elements_to_update

    # Scale component-wise the translation of the element pose, keeping the
    # rotation and the `relative_to` frame. Elements without a pose are ignored.
    def scale_pose_translation(element, scale_vector):
        if getattr(element, "pose", None) is None:
            return
        transform = np.array(element.pose.transform(), dtype=float)
        transform[0:3, 3] = scale_vector * transform[0:3, 3]
        element.pose = rod.Pose.from_transform(
            transform=transform,
            relative_to=element.pose.relative_to,
        )

    # Apply `scale_vector` to the geometry (box, sphere, cylinder or mesh) of
    # each element and update its pose accordingly. For meshes, `mesh_pose`
    # replaces the element pose when `mesh_shape_link` is True.
    def scale_link_elements(
        elements_to_update: list,
        scale_vector: np.ndarray,
        *,
        mesh_pose: rod.Pose | None = None,
        mesh_shape_link: bool = False,
    ) -> None:
        for element in elements_to_update:
            if (
                element is None
                or not hasattr(element, "geometry")
                or element.geometry is None
            ):
                continue
            geometry = element.geometry
            if getattr(geometry, "box", None) is not None:
                current_size = np.array(geometry.box.size, dtype=float)
                geometry.box.size = tuple(
                    float(v) for v in (current_size * scale_vector).tolist()
                )
                scale_pose_translation(element, scale_vector)
            elif getattr(geometry, "sphere", None) is not None:
                # Only the first component of the scale is used for the radius.
                geometry.sphere.radius = float(
                    float(geometry.sphere.radius) * float(scale_vector[0])
                )
                scale_pose_translation(element, scale_vector)
            elif getattr(geometry, "cylinder", None) is not None:
                # Radius is scaled by the first component, length by the third.
                geometry.cylinder.radius = float(
                    float(geometry.cylinder.radius) * float(scale_vector[0])
                )
                geometry.cylinder.length = float(
                    float(geometry.cylinder.length) * float(scale_vector[2])
                )
                scale_pose_translation(element, scale_vector)
            elif getattr(geometry, "mesh", None) is not None:
                # A missing mesh scale is treated as a unit scale.
                base_scale = (
                    np.array(geometry.mesh.scale, dtype=float)
                    if geometry.mesh.scale is not None
                    else unit_scale
                )
                geometry.mesh.scale = tuple(
                    float(v) for v in (base_scale * scale_vector).tolist()
                )
                # Mesh-parametrized reduced links use metadata to preserve
                # the main visual placement in the exported URDF.
                if mesh_shape_link and mesh_pose is not None:
                    element.pose = mesh_pose
                else:
                    scale_pose_translation(element, scale_vector)

    for link_index, link_name in enumerate(self.link_names()):
        if link_name not in links_dict:
            continue

        # Skip links with unsupported shapes
        shape = hw_metadata.link_shape[link_index]
        if shape == LinkParametrizableShape.Unsupported:
            logging.debug(f"Skipping link '{link_name}' with unsupported shape")
            continue

        # Update mass and inertia
        mass = float(self.kin_dyn_parameters.link_parameters.mass[link_index])
        center_of_mass = np.array(
            self.kin_dyn_parameters.link_parameters.center_of_mass[link_index]
        )
        inertia_tensor = LinkParameters.unflatten_inertia_tensor(
            self.kin_dyn_parameters.link_parameters.inertia_elements[link_index]
        )

        links_dict[link_name].inertial.mass = mass
        # The inertial pose is a pure translation to the center of mass
        # (identity rotation).
        L_H_COM = np.eye(4)
        L_H_COM[0:3, 3] = center_of_mass
        links_dict[link_name].inertial.pose = rod.Pose.from_transform(
            transform=L_H_COM,
            relative_to=links_dict[link_name].inertial.pose.relative_to,
        )
        links_dict[link_name].inertial.inertia = rod.Inertia.from_inertia_tensor(
            inertia_tensor=inertia_tensor, validate=True
        )

        # Current geometric parameters of the link stored in the metadata.
        dims = np.array(hw_metadata.geometry[link_index], dtype=float)
        elements_to_update = collect_link_elements(links_dict[link_name])

        # Return the first geometry of kind `attr` ("box", "sphere", ...) found
        # among the elements of the current link, or None. The default argument
        # binds the element list of the current loop iteration.
        def find_reference_geometry(attr: str, elements: list = elements_to_update):
            for element in elements:
                if (
                    element is None
                    or not hasattr(element, "geometry")
                    or element.geometry is None
                ):
                    continue
                geometry = getattr(element.geometry, attr, None)
                if geometry is not None:
                    return geometry
            return None

        # Compute the scale to apply as the ratio between the dimensions in the
        # metadata and those of the reference geometry in the ROD model. A unit
        # scale is used when the reference dimension is (almost) zero.
        if shape == LinkParametrizableShape.Mesh:
            # For meshes, the metadata geometry is used directly as scale vector.
            scale_vector = dims
        elif shape == LinkParametrizableShape.Box:
            ref_box = find_reference_geometry("box")
            if ref_box is None:
                scale_vector = unit_scale
            else:
                base_size = np.array(ref_box.size, dtype=float)
                scale_vector = np.divide(
                    dims,
                    base_size,
                    out=np.ones(3, dtype=float),
                    where=np.abs(base_size) > 1e-12,
                )
        elif shape == LinkParametrizableShape.Sphere:
            ref_sphere = find_reference_geometry("sphere")
            base_radius = (
                float(ref_sphere.radius) if ref_sphere is not None else 1.0
            )
            s = float(dims[0]) / base_radius if abs(base_radius) > 1e-12 else 1.0
            scale_vector = np.array([s, s, s], dtype=float)
        elif shape == LinkParametrizableShape.Cylinder:
            ref_cylinder = find_reference_geometry("cylinder")
            base_radius = (
                float(ref_cylinder.radius) if ref_cylinder is not None else 1.0
            )
            base_length = (
                float(ref_cylinder.length) if ref_cylinder is not None else 1.0
            )
            s_radius = (
                float(dims[0]) / base_radius if abs(base_radius) > 1e-12 else 1.0
            )
            s_length = (
                float(dims[1]) / base_length if abs(base_length) > 1e-12 else 1.0
            )
            scale_vector = np.array([s_radius, s_radius, s_length], dtype=float)
        else:
            scale_vector = unit_scale

        # Remember the scale of this link for the propagation step below.
        link_scale_factors[link_name] = np.array(scale_vector, dtype=float)

        # Visual pose stored in the metadata, expressed relative to the link.
        element_pose = rod.Pose.from_transform(
            transform=np.array(hw_metadata.L_H_vis[link_index]),
            relative_to=link_name,
        )

        scale_link_elements(
            elements_to_update=elements_to_update,
            scale_vector=scale_vector,
            mesh_pose=element_pose,
            mesh_shape_link=(shape == LinkParametrizableShape.Mesh),
        )

        # Update joint poses
        for joint_index in range(self.number_of_joints()):
            # Only the joints flagged in the mask for this link are updated.
            if hw_metadata.L_H_pre_mask[link_index, joint_index]:
                joint_name = js.joint.idx_to_name(
                    model=self, joint_index=joint_index
                )
                if joint_name in joints_dict:
                    joints_dict[joint_name].pose = rod.Pose.from_transform(
                        transform=np.array(
                            hw_metadata.L_H_pre[link_index, joint_index]
                        ),
                        relative_to=link_name,
                    )

    # Propagate link scaling to descendants connected through fixed joints.
    # These links are typically reduced away in the JaxSim model (e.g. feet
    # attached to ankles) but still exist in the exported URDF tree.
    # The loop repeats until no new link receives a scale, so that chains of
    # fixed joints are handled regardless of the iteration order.
    updated = True
    while updated:
        updated = False
        for joint in joints_dict.values():
            if joint.type != "fixed":
                continue
            parent_scale = link_scale_factors.get(joint.parent, None)
            if parent_scale is None or joint.child in link_scale_factors:
                continue
            link_scale_factors[joint.child] = np.array(parent_scale, dtype=float)
            updated = True

    # Scale fixed-joint offsets that are not part of the reduced joint set.
    for joint_name, joint in joints_dict.items():
        if joint.type != "fixed" or joint_name in reduced_joint_names:
            continue
        parent_scale = link_scale_factors.get(joint.parent, unit_scale)
        if np.allclose(parent_scale, unit_scale):
            continue
        if joint.pose is None:
            continue
        transform = np.array(joint.pose.transform(), dtype=float)
        transform[0:3, 3] = parent_scale * transform[0:3, 3]
        joint.pose = rod.Pose.from_transform(
            transform=transform,
            relative_to=joint.pose.relative_to,
        )

    # Apply inherited scaling to non-reduced links (typically descendants
    # connected via fixed joints).
    for link_name, scale_vector in link_scale_factors.items():
        # Links of the reduced model were already scaled in the main loop.
        if link_name in reduced_link_names:
            continue
        if np.allclose(scale_vector, unit_scale):
            continue
        if link_name not in links_dict:
            continue
        scale_link_elements(
            elements_to_update=collect_link_elements(links_dict[link_name]),
            scale_vector=scale_vector,
        )

    # Restore continuous joint types for joints with infinite limits
    # to ensure valid URDF export (continuous joints should not have limits).
    # Continuous joints are internally represented as revolute with infinite
    # limits, but must be exported as type="continuous" for valid URDF.
    for joint in joints_dict.values():
        # Skip if not a revolute joint with axis and limits
        if not (
            joint.type == "revolute"
            and joint.axis is not None
            and joint.axis.limit is not None
        ):
            continue

        lower, upper = joint.axis.limit.lower, joint.axis.limit.upper

        # Check if both limits are infinite (indicating original continuous joint)
        if not (
            lower is not None
            and upper is not None
            and np.isinf(lower)
            and lower < 0
            and np.isinf(upper)
            and upper > 0
        ):
            continue

        # Restore as continuous joint
        joint.type = "continuous"

        # Create a new Limit object with only effort and velocity
        # (no position limits for continuous joints)
        joint.axis.limit = rod.Limit(
            effort=joint.axis.limit.effort,
            velocity=joint.axis.limit.velocity,
            lower=None,
            upper=None,
        )

    # Export the URDF string
    urdf_string = UrdfExporter(pretty=True).to_urdf_string(sdf=rod_model_output)
    return urdf_string
# ==========
# Properties
# ==========
def name(self) -> str:
    """
    Get the name assigned to this model.

    Returns:
        The model name.
    """
    model_name = self.model_name
    return model_name
def number_of_links(self) -> int:
    """
    Get how many links the model has.

    Returns:
        The link count, as reported by the kinematic/dynamic parameters.

    Note:
        The base link is included in the count and its index is always 0.
    """
    parameters = self.kin_dyn_parameters
    return parameters.number_of_links()
def number_of_joints(self) -> int:
    """
    Get how many joints the model has.

    Returns:
        The joint count, as reported by the kinematic/dynamic parameters.
    """
    parameters = self.kin_dyn_parameters
    return parameters.number_of_joints()
def number_of_frames(self) -> int:
    """
    Get how many frames the model has.

    Returns:
        The frame count, as reported by the kinematic/dynamic parameters.
    """
    parameters = self.kin_dyn_parameters
    return parameters.number_of_frames()
# =================
# Base link methods
# =================
def floating_base(self) -> bool:
    """
    Tell whether the base of the model is floating.

    Returns:
        True if the model is floating-base, False otherwise.
    """
    # The first entry of `joint_dofs` holds the DoFs of the base: 6 means floating.
    base_dofs = self.kin_dyn_parameters.joint_model.joint_dofs[0]
    return base_dofs == 6
def base_link(self) -> str:
    """
    Get the name of the base link of the model.

    Returns:
        The name of the base link, i.e. the first entry of the link names.

    Note:
        By default, the base link is the root of the kinematic tree.
    """
    all_link_names = self.link_names()
    return all_link_names[0]
# =====================
# Joint-related methods
# =====================
def dofs(self) -> int:
    """
    Get the total number of degrees of freedom of the joints of the model.

    Returns:
        The number of degrees of freedom of the model.

    Note:
        We do not yet support multi-DoF joints, therefore this is always equal to
        the number of joints. In the future, this could be different.
    """
    # The first entry belongs to the base and is excluded from the count.
    joint_only_dofs = self.kin_dyn_parameters.joint_model.joint_dofs[1:]
    return sum(joint_only_dofs)
def joint_names(self) -> tuple[str, ...]:
    """
    Get the names of the joints of the model.

    Returns:
        The names of the joints in the model.
    """
    # The first stored entry is skipped, consistently with `dofs`.
    stored_names = self.kin_dyn_parameters.joint_model.joint_names
    return stored_names[1:]
# ====================
# Link-related methods
# ====================
def link_names(self) -> tuple[str, ...]:
    """
    Get the names of the links of the model.

    Returns:
        The names of the links in the model.
    """
    parameters = self.kin_dyn_parameters
    return parameters.link_names
# =====================
# Frame-related methods
# =====================
def frame_names(self) -> tuple[str, ...]:
    """
    Get the names of the frames of the model.

    Returns:
        The names of the frames in the model.
    """
    frame_parameters = self.kin_dyn_parameters.frame_parameters
    return frame_parameters.name
# =====================
# Model post-processing
# =====================
def reduce(
    model: JaxSimModel,
    considered_joints: tuple[str, ...],
    locked_joint_positions: dict[str, jtp.FloatLike] | None = None,
) -> JaxSimModel:
    """
    Reduce the model by lumping together the links connected by removed joints.

    Args:
        model: The model to reduce.
        considered_joints: The sequence of joints to consider.
        locked_joint_positions:
            A dictionary containing the positions of the joints to be considered
            in the reduction process. The removed joints in the reduced model
            will have their position locked to their value of this dictionary.
            If a joint is not part of the dictionary, its position is set to zero.

    Returns:
        The reduced model, built with the same settings of the original model.

    Raises:
        ValueError:
            If `locked_joint_positions` contains joints that do not exist
            in the model.
    """

    locked_joint_positions = (
        locked_joint_positions if locked_joint_positions is not None else {}
    )

    # If locked joints are passed, make sure that they are valid.
    # Report the passed joints that are unknown to the model (and not the
    # model joints that were not passed).
    unknown_joints = set(locked_joint_positions) - set(model.joint_names())
    if unknown_joints:
        raise ValueError(f"Passed joints not existing in the model: {unknown_joints}")

    # Operate on a deep copy of the model description in order to prevent problems
    # when mutable attributes are updated.
    intermediate_description = copy.deepcopy(model.description)

    # Update the initial position of the joints.
    # This is necessary to compute the correct pose of the link pairs connected
    # to removed joints.
    for joint_name in set(model.joint_names()) - set(considered_joints):
        j = intermediate_description.joints_dict[joint_name]
        with j.mutable_context():
            j.initial_position = locked_joint_positions.get(joint_name, 0.0)

    # Reduce the model description.
    # If `considered_joints` contains joints not existing in the model,
    # the method will raise an exception.
    reduced_intermediate_description = intermediate_description.reduce(
        considered_joints=list(considered_joints)
    )

    # Build the reduced model.
    reduced_model = JaxSimModel.build(
        model_description=reduced_intermediate_description,
        model_name=model.name(),
        time_step=model.time_step,
        terrain=model.terrain,
        contact_model=model.contact_model,
        contact_params=model.contact_params,
        actuation_params=model.actuation_params,
        gravity=model.gravity,
        integrator=model.integrator,
        constraints=model.kin_dyn_parameters.constraints,
    )

    with reduced_model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
        # Store the origin of the model, in case downstream logic needs it.
        reduced_model.built_from = model.built_from

        # Compute the hw parametrization metadata of the reduced model
        # TODO: move the building of the metadata to KinDynParameters.build()
        # and use the model_description instead of model.built_from.
        reduced_model.kin_dyn_parameters.hw_link_metadata = (
            reduced_model.compute_hw_link_metadata()
        )

    return reduced_model
# ===================
# Inertial properties
# ===================
@jax.jit
@js.common.named_scope
def total_mass(model: JaxSimModel) -> jtp.Float:
    """
    Sum the masses of all the links of the model.

    Args:
        model: The model to consider.

    Returns:
        The total mass of the model.
    """

    link_masses = model.kin_dyn_parameters.link_parameters.mass
    summed_mass = link_masses.sum()
    return summed_mass.astype(float)
@jax.jit
@js.common.named_scope
def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array:
    """
    Build the spatial 6D inertia matrix of every link of the model.

    Args:
        model: The model to consider.

    Returns:
        A 3D array containing the stacked spatial 6D inertia matrices of the links.
    """

    # Map the per-link computation over the stacked link parameters.
    link_parameters = model.kin_dyn_parameters.link_parameters
    batched_spatial_inertia = jax.vmap(
        js.kin_dyn_parameters.LinkParameters.spatial_inertia
    )
    return batched_spatial_inertia(link_parameters)
# ==============================
# Rigid Body Dynamics Algorithms
# ==============================
def _adjoint_from_rotation_translation(
    rotation: jtp.Matrix,
    translation: jtp.Vector,
) -> jtp.Matrix:
    """
    Assemble the block matrix ``[[R, S R], [0, R]]`` from a rotation and a translation.

    Here ``R`` is `rotation` and ``S`` is ``Skew.wedge(translation)``. Leading batch
    dimensions are supported through the ellipsis of the einsum contraction.

    Args:
        rotation: The rotation block ``R``.
        translation: The translation passed to ``Skew.wedge``.

    Returns:
        The assembled matrix, with blocks concatenated along the last two axes.
    """
    wedge_times_rotation = jnp.einsum(
        "...ij,...jk->...ik", Skew.wedge(translation), rotation
    )

    upper_rows = jnp.concatenate([rotation, wedge_times_rotation], axis=-1)
    lower_rows = jnp.concatenate([jnp.zeros_like(rotation), rotation], axis=-1)

    return jnp.concatenate([upper_rows, lower_rows], axis=-2)
def _inverse_adjoint_from_rotation_translation(
    rotation: jtp.Matrix,
    translation: jtp.Vector,
) -> jtp.Matrix:
    """
    Assemble the block matrix ``[[Rᵀ, -Rᵀ S], [0, Rᵀ]]`` from a rotation and a translation.

    Here ``Rᵀ`` is `rotation` with its last two axes swapped and ``S`` is
    ``Skew.wedge(translation)``. Leading batch dimensions are supported through
    the ellipsis of the einsum contraction.

    Args:
        rotation: The rotation ``R`` (transposed internally).
        translation: The translation passed to ``Skew.wedge``.

    Returns:
        The assembled matrix, with blocks concatenated along the last two axes.
    """
    transposed = jnp.swapaxes(rotation, -1, -2)
    minus_transposed_times_wedge = -jnp.einsum(
        "...ij,...jk->...ik", transposed, Skew.wedge(translation)
    )

    upper_rows = jnp.concatenate([transposed, minus_transposed_times_wedge], axis=-1)
    lower_rows = jnp.concatenate([jnp.zeros_like(transposed), transposed], axis=-1)

    return jnp.concatenate([upper_rows, lower_rows], axis=-2)
def _apply_input_representation_to_jacobian(
jacobian: jtp.Matrix,
base_transform: jtp.Matrix,
) -> jtp.Matrix:
transformed_base = jnp.einsum(
"...ij,jk->...ik",
jacobian[..., :, 0:6],
base_transform,
)
return jnp.concatenate([transformed_base, jacobian[..., :, 6:]], axis=-1)
def _apply_input_representation_derivative_to_jacobian(
jacobian: jtp.Matrix,
base_transform_derivative: jtp.Matrix,
) -> jtp.Matrix:
transformed_base = jnp.einsum(
"...ij,jk->...ik",
jacobian[..., :, 0:6],
base_transform_derivative,
)
return jnp.concatenate(
[transformed_base, jnp.zeros_like(jacobian[..., :, 6:])],
axis=-1,
)
def _link_jacobian_support_mask(
model: JaxSimModel,
*,
dtype: jnp.dtype,
) -> jtp.Matrix:
κb = model.kin_dyn_parameters.support_body_array_bool
return jnp.concatenate(
[
jnp.ones((model.number_of_links(), 5), dtype=dtype),
jnp.asarray(κb, dtype=dtype),
],
axis=1,
)
def _body_input_transform(
    data: js.data.JaxSimModelData,
) -> tuple[jtp.Matrix, jtp.Matrix]:
    """
    Compute the 6x6 transform of the base velocity input and its time derivative.

    The result depends on the velocity representation active in `data`:
    for the body representation the transform is the identity and its
    derivative is zero.

    Args:
        data: The data providing the base transform, the base velocity, and
            the active velocity representation.

    Returns:
        A tuple with the transform and its derivative.

    Raises:
        ValueError: If the velocity representation is not a supported one.
    """
    base_transform = data._base_transform
    base_rotation = base_transform[0:3, 0:3]
    dtype = base_transform.dtype
    vel_repr = data.velocity_representation

    if vel_repr == VelRepr.Inertial:
        B_X_I = _inverse_adjoint_from_rotation_translation(
            rotation=base_rotation,
            translation=base_transform[0:3, 3],
        )
        B_Ẋ_I = -B_X_I @ Cross.vx(data.base_velocity)

    elif vel_repr == VelRepr.Body:
        B_X_I = jnp.eye(6, dtype=dtype)
        B_Ẋ_I = jnp.zeros((6, 6), dtype=dtype)

    elif vel_repr == VelRepr.Mixed:
        # Only the rotation is used: the translation is set to zero.
        B_X_I = _inverse_adjoint_from_rotation_translation(
            rotation=base_rotation,
            translation=jnp.zeros(3, dtype=dtype),
        )
        # Zero the first three components of the base velocity.
        BW_v_BW_B = data.base_velocity.at[0:3].set(jnp.zeros(3, dtype=dtype))
        B_Ẋ_I = -B_X_I @ Cross.vx(BW_v_BW_B)

    else:
        raise ValueError(data.velocity_representation)

    return B_X_I, B_Ẋ_I
def _link_output_adjoint_from_body(
    data: js.data.JaxSimModelData,
    B_H_L: jtp.Matrix,
    *,
    output_vel_repr: VelRepr,
) -> jtp.Matrix:
    """
    Compute the 6x6 transform that maps body-expressed quantities of the links
    to the requested output velocity representation.

    Args:
        data: The data providing the base transform.
        B_H_L: The transforms of the links, indexed as ``[..., 4, 4]``.
        output_vel_repr: The output velocity representation.

    Returns:
        The transform for the requested representation. For the inertial
        representation it is built from the base transform only, hence it has
        no leading link axis.

    Raises:
        ValueError: If the output velocity representation is not a supported one.
    """
    base_transform = data._base_transform
    base_rotation = base_transform[0:3, 0:3]

    link_rotations = B_H_L[..., 0:3, 0:3]
    link_positions = B_H_L[..., 0:3, 3]

    if output_vel_repr == VelRepr.Inertial:
        return _adjoint_from_rotation_translation(
            rotation=base_rotation,
            translation=base_transform[0:3, 3],
        )

    if output_vel_repr == VelRepr.Body:
        return _inverse_adjoint_from_rotation_translation(
            rotation=link_rotations,
            translation=link_positions,
        )

    if output_vel_repr == VelRepr.Mixed:
        # Rotate the negated link positions with the base rotation, and use
        # the base rotation (broadcast to all links) as rotation block.
        W_p_B_in_LW = -jnp.einsum("ij,...j->...i", base_rotation, link_positions)
        W_R_B = jnp.broadcast_to(base_rotation, link_rotations.shape)
        return _adjoint_from_rotation_translation(
            rotation=W_R_B,
            translation=W_p_B_in_LW,
        )

    raise ValueError(output_vel_repr)
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def generalized_free_floating_jacobian(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    """
    Compute the free-floating jacobians of all links.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        output_vel_repr:
            The output velocity representation of the free-floating jacobians.
            If None, the velocity representation of `data` is used.

    Returns:
        The `(nL, 6, 6+dofs)` array containing the stacked free-floating
        jacobians of the links. The first axis is the link index.

    Note:
        The v-stacked version of the returned Jacobian array together with the
        flattened 6D forces of the links, are useful to compute the `J.T @ f`
        product of the multi-body EoM.
    """

    if output_vel_repr is None:
        output_vel_repr = data.velocity_representation

    # Compute the doubly-left free-floating full jacobian.
    full_jacobian, B_H_L = jaxsim.rbda.jacobian_full_doubly_left(
        model=model,
        joint_positions=data.joint_positions,
    )

    # Obtain one Jacobian per link by masking the columns of the full Jacobian.
    link_mask = _link_jacobian_support_mask(model=model, dtype=full_jacobian.dtype)
    masked_jacobians = link_mask[:, jnp.newaxis, :] * full_jacobian[jnp.newaxis, ...]

    # Adjust the input representation (base columns of the Jacobians).
    input_transform, _ = _body_input_transform(data=data)
    jacobians_with_input_repr = _apply_input_representation_to_jacobian(
        jacobian=masked_jacobians,
        base_transform=input_transform,
    )

    # Adjust the output representation (rows of the Jacobians).
    output_transform = _link_output_adjoint_from_body(
        data=data,
        B_H_L=B_H_L,
        output_vel_repr=output_vel_repr,
    )

    return jnp.einsum(
        "...ij,...jk->...ik", output_transform, jacobians_with_input_repr
    )
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def generalized_free_floating_jacobian_derivative(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    """
    Compute the free-floating jacobian derivatives of all links.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian derivatives.
            If None, the velocity representation of `data` is used.

    Returns:
        The `(nL, 6, 6+dofs)` array containing the stacked free-floating
        jacobian derivatives of the links. The first axis is the link index.

    Raises:
        ValueError: If the output velocity representation is not a supported one.
    """

    output_vel_repr = (
        output_vel_repr if output_vel_repr is not None else data.velocity_representation
    )

    # Compute the derivative of the doubly-left free-floating full jacobian.
    B_J̇_full_WX_B, B_H_L = jaxsim.rbda.jacobian_derivative_full_doubly_left(
        model=model,
        joint_positions=data.joint_positions,
        joint_velocities=data.joint_velocities,
    )

    # The derivative of the equation to change the input and output representations
    # of the Jacobian derivative needs the computation of the plain link Jacobian.
    B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
        model=model,
        joint_positions=data.joint_positions,
    )

    # Obtain one Jacobian (and one Jacobian derivative) per link by masking the
    # columns of the full matrices. The results have a leading link axis.
    support_mask = _link_jacobian_support_mask(model=model, dtype=B_J̇_full_WX_B.dtype)
    B_J̇_WL_B = support_mask[:, jnp.newaxis, :] * B_J̇_full_WX_B[jnp.newaxis, ...]
    B_J_WL_B = support_mask[:, jnp.newaxis, :] * B_J_full_WL_B[jnp.newaxis, ...]

    # Transform of the base input (and its time derivative) for the velocity
    # representation active in `data`.
    B_X_I, B_Ẋ_I = _body_input_transform(data=data)

    # Jacobian with the input representation applied to its base columns.
    B_J_WL_I = _apply_input_representation_to_jacobian(
        jacobian=B_J_WL_B,
        base_transform=B_X_I,
    )

    # Jacobian derivative with the input representation applied to its base columns.
    B_J̇_WL_input = _apply_input_representation_to_jacobian(
        jacobian=B_J̇_WL_B,
        base_transform=B_X_I,
    )

    # Term due to the time derivative of the input transform. It only affects
    # the base columns (the helper returns zeros for the remaining columns).
    B_J̇_WL_repr = _apply_input_representation_derivative_to_jacobian(
        jacobian=B_J_WL_B,
        base_transform_derivative=B_Ẋ_I,
    )

    # Base velocity mapped through the input transform, stacked with the joint
    # velocities, and the resulting per-link velocities obtained through the
    # per-link Jacobians.
    B_v_WB = B_X_I @ data.base_velocity
    B_ν = jnp.concatenate([B_v_WB, data.joint_velocities])
    B_v_WL = jnp.einsum("bij,j->bi", B_J_WL_B, B_ν)

    # Transform used to express the output in the requested representation.
    O_X_B = _link_output_adjoint_from_body(
        data=data,
        B_H_L=B_H_L,
        output_vel_repr=output_vel_repr,
    )

    # Compute the time derivative of the output transform, whose expression
    # depends on the output velocity representation.
    match output_vel_repr:
        case VelRepr.Inertial:
            O_Ẋ_B = O_X_B @ Cross.vx(B_v_WB)

        case VelRepr.Body:
            # Velocity of each link relative to the base.
            B_v_B_L = B_v_WL - B_v_WB
            O_Ẋ_B = -jnp.einsum("...ij,...jk->...ik", O_X_B, Cross.vx(B_v_B_L))

        case VelRepr.Mixed:
            # Rebuild the inverse of the mixed output transform from the same
            # rotation and translation used by `_link_output_adjoint_from_body`.
            base_rotation = data._base_transform[0:3, 0:3]
            B_p_L = B_H_L[..., 0:3, 3]
            W_p_B_in_LW = -jnp.einsum("ij,...j->...i", base_rotation, B_p_L)
            W_R_B = jnp.broadcast_to(base_rotation, B_H_L[..., 0:3, 0:3].shape)
            B_X_LW = _inverse_adjoint_from_rotation_translation(
                rotation=W_R_B,
                translation=W_p_B_in_LW,
            )

            LW_v_WL = jnp.einsum("...ij,...j->...i", O_X_B, B_v_WL)
            # Copy of the link velocities with the last three components zeroed.
            LW_v_W_LW = LW_v_WL.at[..., 3:6].set(jnp.zeros_like(LW_v_WL[..., 3:6]))
            LW_v_LW_L = LW_v_WL - LW_v_W_LW
            LW_v_B_LW = LW_v_WL - jnp.einsum("...ij,j->...i", O_X_B, B_v_WB) - LW_v_LW_L

            O_Ẋ_B = -jnp.einsum(
                "...ij,...jk->...ik",
                O_X_B,
                Cross.vx(jnp.einsum("...ij,...j->...i", B_X_LW, LW_v_B_LW)),
            )

        case _:
            raise ValueError(output_vel_repr)

    # Sum the three terms of the derivative: the one due to the derivative of the
    # output transform, the one due to the Jacobian derivative, and the one due to
    # the derivative of the input transform.
    O_J̇_WL_I = jnp.einsum("...ij,...jk->...ik", O_Ẋ_B, B_J_WL_I)
    O_J̇_WL_I += jnp.einsum("...ij,...jk->...ik", O_X_B, B_J̇_WL_input)
    O_J̇_WL_I += jnp.einsum("...ij,...jk->...ik", O_X_B, B_J̇_WL_repr)

    return O_J̇_WL_I
@functools.partial(jax.jit, static_argnames=["prefer_aba"])
def forward_dynamics(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    prefer_aba: bool = True,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute the forward dynamics of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces:
            The joint forces to consider as a vector of shape `(dofs,)`.
        link_forces:
            The link 6D forces to consider as a matrix of shape `(nL, 6)`.
            The frame in which they are expressed must be `data.velocity_representation`.
        prefer_aba: Whether to prefer the ABA algorithm over the CRB one.

    Returns:
        A tuple containing the 6D acceleration in the active representation of the
        base link and the joint accelerations resulting from the application of the
        considered joint forces and external forces.
    """

    # `prefer_aba` is a static argument, therefore the algorithm is selected
    # in Python at trace time.
    forward_dynamics_fn = forward_dynamics_aba if prefer_aba else forward_dynamics_crb

    return forward_dynamics_fn(
        model=model,
        data=data,
        joint_forces=joint_forces,
        link_forces=link_forces,
    )
@functools.partial(jax.jit, static_argnames=("parallel",))
@js.common.named_scope
def forward_dynamics_aba(
model: JaxSimModel,
data: js.data.JaxSimModelData,
*,
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
parallel: bool = False,
) -> tuple[jtp.Vector, jtp.Vector]:
"""
Compute the forward dynamics of the model with the ABA algorithm.
Args:
model: The model to consider.
data: The data of the considered model.
joint_forces:
The joint forces to consider as a vector of shape `(dofs,)`.
link_forces:
The link 6D forces to consider as a matrix of shape `(nL, 6)`.
The frame in which they are expressed must be `data.velocity_representation`.
parallel:
If ``True``, use the level-parallel ABA implementation that
processes independent tree branches simultaneously.
Beneficial on GPU or for wide/deep kinematic trees.
Returns:
A tuple containing the 6D acceleration in the active representation of the
base link and the joint accelerations resulting from the application of the
considered joint forces and external forces.
"""
# ============
# Prepare data
# ============
# Build joint forces, if not provided.
τ = (
jnp.atleast_1d(joint_forces.squeeze())
if joint_forces is not None
else jnp.zeros_like(data.joint_positions)
)
# Build link forces, if not provided.
f_L = (
jnp.atleast_2d(link_forces.squeeze())
if link_forces is not None
else jnp.zeros((model.number_of_links(), 6))
)
# Create a references object that simplifies converting among representations.
references = js.references.JaxSimModelReferences.build(
model=model,
joint_force_references=τ,
link_forces=f_L,
data=data,
velocity_representation=data.velocity_representation,
)
# Extract the state in inertial-fixed representation.
with data.switch_velocity_representation(VelRepr.Inertial):
W_p_B = data.base_position
W_v_WB = data.base_velocity
W_Q_B = data.base_orientation
s = data.joint_positions
ṡ = data.joint_velocities
# Extract the inputs in inertial-fixed representation.
W_f_L = references._link_forces
τ = references._joint_force_references
# ========================
# Compute forward dynamics
# ========================
aba_fn = jaxsim.rbda.aba_parallel if parallel else jaxsim.rbda.aba
W_v̇_WB, s̈ = aba_fn(
model=model,
base_position=W_p_B,
base_quaternion=W_Q_B,
joint_positions=s,
base_linear_velocity=W_v_WB[0:3],
base_angular_velocity=W_v_WB[3:6],
joint_velocities=ṡ,
joint_transforms=model.kin_dyn_parameters.joint_transforms(
joint_positions=s,
base_transform=data.base_transform,
),
joint_forces=τ,
link_forces=W_f_L,
standard_gravity=model.gravity,
)
# =============
# Adjust output
# =============
def to_active(
W_v̇_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WB: jtp.Vector, W_v_WC: jtp.Vector
) -> jtp.Vector:
"""
Convert the inertial-fixed apparent base acceleration W_v̇_WB to
another representation C_v̇_WB expressed in a generic frame C.
"""
# In Mixed representation, we need to include a cross product in ℝ⁶.
# In Inertial and Body representations, the cross product is always zero.
C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB)
match data.velocity_representation:
case VelRepr.Inertial:
# In this case C=W
W_H_C = W_H_W = jnp.eye(4) # noqa: F841
W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
case VelRepr.Body:
# In this case C=B
W_H_C = W_H_B = data._base_transform
W_v_WC = W_v_WB
case VelRepr.Mixed:
# In this case C=B[W]
W_H_B = data._base_transform
W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841
W_ṗ_B = data.base_velocity[0:3]
W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841
case _:
raise ValueError(data.velocity_representation)
# We need to convert the derivative of the base velocity to the active
# representation. In Mixed representation, this conversion is not a plain
# transformation with just X, but it also involves a cross product in ℝ⁶.
C_v̇_WB = to_active(
W_v̇_WB=W_v̇_WB,
W_H_C=W_H_C,
W_v_WB=W_v_WB,
W_v_WC=W_v_WC,
)
# The ABA algorithm already returns a zero base 6D acceleration for
# fixed-based models. However, the to_active function introduces an
# additional acceleration component in Mixed representation.
# Here below we make sure that the base acceleration is zero.
C_v̇_WB = C_v̇_WB if model.floating_base() else jnp.zeros(6)
return C_v̇_WB.astype(float), s̈.astype(float)
@jax.jit
@js.common.named_scope
def forward_dynamics_crb(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Solve the forward dynamics by assembling and inverting the equations of motion.

    The mass matrix, the bias forces and the link Jacobians are computed
    explicitly, and the resulting linear system is solved for the generalized
    acceleration.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces:
            The joint forces to consider as a vector of shape `(dofs,)`.
            Defaults to zero forces.
        link_forces:
            The link 6D forces to consider as a matrix of shape `(nL, 6)`.
            The frame in which they are expressed must be `data.velocity_representation`.
            Defaults to zero forces.

    Returns:
        A tuple containing the 6D acceleration in the active representation of the
        base link and the joint accelerations resulting from the application of the
        considered joint forces and external forces.

    Note:
        Compared to ABA, this method could be significantly slower, especially for
        models with a large number of degrees of freedom.
    """

    # ------------------
    # Inputs preparation
    # ------------------

    # Joint torques, defaulting to zero.
    if joint_forces is not None:
        tau = jnp.atleast_1d(joint_forces)
    else:
        tau = jnp.zeros_like(data.joint_positions)

    # External 6D forces applied to the links, defaulting to zero.
    if link_forces is not None:
        wrenches = jnp.atleast_2d(link_forces)
    else:
        wrenches = jnp.zeros(shape=(model.number_of_links(), 6))

    # Terms of the floating-base equations of motion.
    mass_matrix = free_floating_mass_matrix(model=model, data=data)
    bias_forces = free_floating_bias_forces(model=model, data=data)

    # Selector matrix mapping the joint torques to generalized forces.
    n_dofs = model.dofs()
    selector = jnp.block([jnp.zeros(shape=(n_dofs, 6)), jnp.eye(n_dofs)]).T

    link_jacobians = generalized_free_floating_jacobian(model=model, data=data)

    # TODO: invert the Mss block exploiting sparsity defined by the parent array λ(i)

    # ---------------------------
    # Solve for the accelerations
    # ---------------------------

    if not model.floating_base():
        # Fixed-base model: only the joint block of the system is solved,
        # and the base acceleration is set to zero.
        # l: number of links.
        # j: number of joints.
        projected = jnp.einsum("l6j,l6->j", link_jacobians[:, :, 6:], wrenches)
        joint_acc = jnp.linalg.solve(
            mass_matrix[6:, 6:], tau - bias_forces[6:] + projected
        )
        base_acc = jnp.zeros(6)
        nu_dot = jnp.hstack([base_acc, joint_acc.squeeze()])
    else:
        # Floating-base model: solve the full system.
        # l: number of links.
        # g: generalized coordinates, 6 + number of joints.
        projected = jnp.einsum("l6g,l6->g", link_jacobians, wrenches)
        nu_dot = jnp.linalg.solve(
            mass_matrix, selector @ tau - bias_forces + projected
        )

    # -----------------
    # Output extraction
    # -----------------

    # Base acceleration in the active representation.
    # Note that this is an apparent acceleration (relevant in Mixed representation),
    # therefore it cannot be always expressed in different frames with just a
    # 6D transformation X.
    base_acceleration = nu_dot[0:6].squeeze().astype(float)

    # Joint accelerations, always returned as a 1D array.
    joint_accelerations = jnp.atleast_1d(nu_dot[6:].squeeze()).astype(float)

    return base_acceleration, joint_accelerations
@functools.partial(jax.jit, static_argnames=("parallel",))
@js.common.named_scope
def forward_kinematics(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    parallel: bool = False,
) -> jtp.Matrix:
    """
    Compute the world transforms of all the links of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        parallel:
            If True, use the level-parallel FK implementation that
            processes independent tree branches simultaneously.

    Returns:
        The nL x 4 x 4 array containing the stacked homogeneous transformations
        of the links. The first axis is the link index.
    """

    # Pick the implementation (static flag, resolved at trace time).
    if parallel:
        kinematics_fn = jaxsim.rbda.forward_kinematics_model_parallel
    else:
        kinematics_fn = jaxsim.rbda.forward_kinematics_model

    # The joint transforms are recomputed from the model instead of being read
    # from data, to ensure gradients flow through model parameters.
    transforms = model.kin_dyn_parameters.joint_transforms(
        joint_positions=data.joint_positions,
        base_transform=data.base_transform,
    )

    # Only the first output (link transforms) is returned, the second is discarded.
    link_transforms, _ = kinematics_fn(
        model=model,
        base_position=data.base_position,
        base_quaternion=data.base_quaternion,
        joint_positions=data.joint_positions,
        joint_velocities=data.joint_velocities,
        base_linear_velocity_inertial=data._base_linear_velocity,
        base_angular_velocity_inertial=data._base_angular_velocity,
        joint_transforms=transforms,
    )

    return link_transforms
def _transform_M_block(M_body: jtp.Matrix, X: jtp.Matrix) -> jtp.Matrix:
"""
Apply invTᵀ M_body invT with invT = diag(X, I_n), without forming invT.
Args:
M_body: (6+n, 6+n) mass matrix (inverse) in body representation.
X: (6, 6) adjoint (e.g. B_X_W or B_X_BW).
Returns:
M_repr: (6+n, 6+n) mass matrix (inverse) in the new representation.
"""
# invTᵀ M invT with invT = diag(X, I):
# Mbb' = Xᵀ Mbb X
# Mbj' = Xᵀ Mbj
# Mjb' = Mjb X
# Mjj' = Mjj
Mbb_t = X.T @ M_body[:6, :6] @ X
Mbj_t = X.T @ M_body[:6, 6:]
Mjb_t = M_body[6:, :6] @ X
Mjj_t = M_body[6:, 6:]
top = jnp.concatenate([Mbb_t, Mbj_t], axis=1)
bottom = jnp.concatenate([Mjb_t, Mjj_t], axis=1)
return jnp.concatenate([top, bottom], axis=0)
@jax.jit
@js.common.named_scope
def free_floating_mass_matrix(
model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
"""
Compute the free-floating mass matrix of the model with the CRBA algorithm.
Args:
model: The model to consider.
data: The data of the considered model.
Returns:
The free-floating mass matrix of the model.
"""
M_body = jaxsim.rbda.crba(
model=model,
joint_positions=data.joint_positions,
)
match data.velocity_representation:
case VelRepr.Body:
return M_body
case VelRepr.Inertial:
B_X_W = Adjoint.from_transform(transform=data.base_transform, inverse=True)
return _transform_M_block(M_body, B_X_W)
case VelRepr.Mixed:
BW_H_B = data.base_transform.at[0:3, 3].set(jnp.zeros(3))
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
return _transform_M_block(M_body, B_X_BW)
case _:
raise ValueError(data.velocity_representation)
@jax.jit
@js.common.named_scope
def free_floating_mass_matrix_inverse(
model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
"""
Compute the inverse of the free-floating mass matrix of the model
with the CRBA algorithm.
Args:
model: The model to consider.
data: The data of the considered model.
Returns:
The inverse of the free-floating mass matrix of the model.
"""
M_inv_body = jaxsim.rbda.mass_inverse(
model=model,
joint_transforms=data._joint_transforms,
)
match data.velocity_representation:
case VelRepr.Body:
return M_inv_body
case VelRepr.Inertial:
W_X_B = Adjoint.from_transform(transform=data.base_transform)
return _transform_M_block(M_inv_body, W_X_B.T)
case VelRepr.Mixed:
B_H_BW = data.base_transform.at[0:3, 3].set(jnp.zeros(3))
BW_X_B = Adjoint.from_transform(transform=B_H_BW)
return _transform_M_block(M_inv_body, BW_X_B.T)
case _:
raise ValueError(data.velocity_representation)
@jax.jit
@js.common.named_scope
def free_floating_coriolis_matrix(
    model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    """
    Compute the free-floating Coriolis matrix of the model.

    The matrix is first assembled in body-fixed representation by summing the
    contribution of each link, and it is then converted to the velocity
    representation active in `data`.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The free-floating Coriolis matrix of the model.

    Raises:
        ValueError: If the velocity representation of `data` is not supported.

    Note:
        This function, contrarily to other quantities of the equations of motion,
        does not exploit any iterative algorithm. Therefore, the computation of
        the Coriolis matrix may be much slower than other quantities.
    """

    # We perform all the calculation in body-fixed.
    # The Coriolis matrix computed in this representation is converted later
    # to the active representation stored in data.
    with data.switch_velocity_representation(VelRepr.Body):

        B_ν = data.generalized_velocity

        # Doubly-left free-floating Jacobian.
        L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)

        # Doubly-left free-floating Jacobian derivative.
        L_J̇_WL_B = generalized_free_floating_jacobian_derivative(
            model=model, data=data
        )

    # Spatial inertia matrices of the links (they only depend on the model).
    L_M_L = link_spatial_inertia_matrices(model=model)

    # Body-fixed link velocities.
    # Note: we could have called link.velocity() instead of computing it ourselves,
    # but since we need the link Jacobians later, we can save a double calculation.
    L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B)

    # Compute the contribution of each link to the Coriolis matrix.
    # Arguments are the link inertia M, the link velocity v, the link
    # Jacobian J and its derivative J̇.
    def compute_link_contribution(M, v, J, J̇) -> jtp.Array:

        return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇)

    # Evaluate the contribution of all links at once (first axis is the link).
    C_B_links = jax.vmap(compute_link_contribution)(
        L_M_L,
        L_v_WL,
        L_J_WL_B,
        L_J̇_WL_B,
    )

    # We need to adjust the Coriolis matrix for fixed-base models.
    # In this case, the base link does not contribute to the matrix, and we need to zero
    # the off-diagonal terms mapping joint quantities onto the base configuration.
    if model.floating_base():
        C_B = C_B_links.sum(axis=0)
    else:
        # Skip the first entry (the base link) and zero the base-joint couplings.
        C_B = C_B_links[1:].sum(axis=0)
        C_B = C_B.at[0:6, 6:].set(0.0)
        C_B = C_B.at[6:, 0:6].set(0.0)

    # Adjust the representation of the Coriolis matrix.
    # Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6.
    match data.velocity_representation:

        case VelRepr.Body:
            # Already in the desired representation.
            return C_B

        case VelRepr.Inertial:

            n = model.dofs()
            W_H_B = data._base_transform

            # Block-diagonal transform acting only on the base part of
            # the generalized velocity.
            B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True)
            B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n))

            with data.switch_velocity_representation(VelRepr.Inertial):
                W_v_WB = data.base_velocity
                B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)

            # Time derivative of the block-diagonal transform (joint block is zero).
            B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n)))

            # Mass matrix in body-fixed representation.
            with data.switch_velocity_representation(VelRepr.Body):
                M = free_floating_mass_matrix(model=model, data=data)

            C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W)

            return C

        case VelRepr.Mixed:

            n = model.dofs()

            # Base transform with zeroed translation.
            BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))
            B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)
            B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n))

            with data.switch_velocity_representation(VelRepr.Mixed):
                BW_v_WB = data.base_velocity
                # Keep only the first three components of the base velocity.
                BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))

            # Remaining part of the base velocity (last three components).
            BW_v_BW_B = BW_v_WB - BW_v_W_BW
            B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)

            # Time derivative of the block-diagonal transform (joint block is zero).
            B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n)))

            # Mass matrix in body-fixed representation.
            with data.switch_velocity_representation(VelRepr.Body):
                M = free_floating_mass_matrix(model=model, data=data)

            C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ B_T_BW)

            return C

        case _:
            raise ValueError(data.velocity_representation)
@jax.jit
@js.common.named_scope
def inverse_dynamics(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_accelerations: jtp.VectorLike | None = None,
    base_acceleration: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute inverse dynamics with the RNEA algorithm.

    The inputs are converted to inertial-fixed representation, RNEA is executed,
    and the resulting base force is converted back to the representation
    active in `data`.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_accelerations:
            The joint accelerations to consider as a vector of shape `(dofs,)`.
            Zero accelerations are used if `None`.
        base_acceleration:
            The base acceleration to consider as a vector of shape `(6,)`.
            A zero acceleration is used if `None`.
        link_forces:
            The link 6D forces to consider as a matrix of shape `(nL, 6)`.
            The frame in which they are expressed must be `data.velocity_representation`.
            Zero forces are used if `None`.

    Returns:
        A tuple containing the 6D force in the active representation applied to the
        base to obtain the considered base acceleration, and the joint forces to apply
        to obtain the considered joint accelerations.

    Raises:
        ValueError: If the velocity representation of `data` is not supported.
    """

    # ============
    # Prepare data
    # ============

    # Build joint accelerations, if not provided.
    s̈ = (
        jnp.atleast_1d(jnp.array(joint_accelerations).squeeze())
        if joint_accelerations is not None
        else jnp.zeros_like(data.joint_positions)
    )

    # Build base acceleration, if not provided.
    v̇_WB = (
        jnp.array(base_acceleration).squeeze()
        if base_acceleration is not None
        else jnp.zeros(6)
    )

    # Build link forces, if not provided.
    f_L = (
        jnp.atleast_2d(jnp.array(link_forces).squeeze())
        if link_forces is not None
        else jnp.zeros(shape=(model.number_of_links(), 6))
    )

    def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):
        """
        Convert the active representation of the base acceleration C_v̇_WB
        expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
        """

        W_X_C = Adjoint.from_transform(transform=W_H_C)
        C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)

        # Express the velocity of the frame C in C itself.
        C_v_WC = C_X_W @ W_v_WC

        # In Mixed representation, we need to include a cross product in ℝ⁶.
        # In Inertial and Body representations, the cross product is always zero.
        return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB)

    # Select the transform and the velocity of the frame C associated
    # with the active representation.
    match data.velocity_representation:
        case VelRepr.Inertial:
            # In this case C=W
            W_H_C = W_H_W = jnp.eye(4)  # noqa: F841
            W_v_WC = W_v_WW = jnp.zeros(6)  # noqa: F841

        case VelRepr.Body:
            # In this case C=B
            W_H_C = W_H_B = data._base_transform
            with data.switch_velocity_representation(VelRepr.Inertial):
                W_v_WC = W_v_WB = data.base_velocity

        case VelRepr.Mixed:
            # In this case C=B[W]
            W_H_B = data._base_transform
            W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))  # noqa: F841
            # Here data.base_velocity is read in the active (Mixed) representation.
            W_ṗ_B = data.base_velocity[0:3]
            W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)  # noqa: F841

        case _:
            raise ValueError(data.velocity_representation)

    # We need to convert the derivative of the base acceleration to the Inertial
    # representation. In Mixed representation, this conversion is not a plain
    # transformation with just X, but it also involves a cross product in ℝ⁶.
    W_v̇_WB = to_inertial(
        C_v̇_WB=v̇_WB,
        W_H_C=W_H_C,
        C_v_WB=data.base_velocity,
        W_v_WC=W_v_WC,
    )

    # Create a references object that simplifies converting among representations.
    references = js.references.JaxSimModelReferences.build(
        model=model,
        data=data,
        link_forces=f_L,
        velocity_representation=data.velocity_representation,
    )

    # Extract the state in inertial-fixed representation.
    with data.switch_velocity_representation(VelRepr.Inertial):
        W_p_B = data.base_position
        W_v_WB = data.base_velocity
        W_Q_B = data.base_quaternion
        s = data.joint_positions
        ṡ = data.joint_velocities

    # Extract the inputs in inertial-fixed representation.
    W_f_L = references._link_forces

    # ========================
    # Compute inverse dynamics
    # ========================

    # RNEA uses the joint transforms stored in data.
    W_f_B, τ = jaxsim.rbda.rnea(
        model=model,
        base_position=W_p_B,
        base_quaternion=W_Q_B,
        joint_positions=s,
        base_linear_velocity=W_v_WB[0:3],
        base_angular_velocity=W_v_WB[3:6],
        joint_velocities=ṡ,
        base_linear_acceleration=W_v̇_WB[0:3],
        base_angular_acceleration=W_v̇_WB[3:6],
        joint_accelerations=s̈,
        joint_transforms=data._joint_transforms,
        link_forces=W_f_L,
        standard_gravity=model.gravity,
    )

    # =============
    # Adjust output
    # =============

    # Express W_f_B in the active representation.
    f_B = js.data.JaxSimModelData.inertial_to_other_representation(
        array=W_f_B,
        other_representation=data.velocity_representation,
        transform=data._base_transform,
        is_force=True,
    ).squeeze()

    return f_B.astype(float), τ.astype(float)
@jax.jit
@js.common.named_scope
def free_floating_gravity_forces(
    model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the free-floating gravity forces :math:`g(\mathbf{q})` of the model.

    The forces are obtained by running the inverse dynamics on a copy of the
    state that keeps only the positions, with zero accelerations and zero
    external forces.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The free-floating gravity forces of the model.
    """

    # State sharing the positions of the input data. The velocities are not
    # passed to the builder, so they are not inherited from the input data.
    static_data = js.data.JaxSimModelData.build(
        model=model,
        velocity_representation=data.velocity_representation,
        base_position=data.base_position,
        base_quaternion=data.base_quaternion,
        joint_positions=data.joint_positions,
    )

    # Zero inputs of the inverse dynamics.
    zero_joint_accelerations = jnp.atleast_1d(jnp.zeros(model.dofs()))
    zero_base_acceleration = jnp.zeros(6)
    zero_link_forces = jnp.zeros(shape=(model.number_of_links(), 6))

    base_force, joint_forces = inverse_dynamics(
        model=model,
        data=static_data,
        joint_accelerations=zero_joint_accelerations,
        base_acceleration=zero_base_acceleration,
        link_forces=zero_link_forces,
    )

    # Stack the base and joint components in a single generalized force.
    generalized_force = jnp.hstack((base_force, joint_forces))

    return generalized_force.astype(float)
@jax.jit
@js.common.named_scope
def free_floating_bias_forces(
    model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the free-floating bias forces :math:`h(\mathbf{q}, \boldsymbol{\nu})`
    of the model.

    The forces are obtained by running the inverse dynamics on a copy of the
    state, with zero accelerations and zero external forces.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The free-floating bias forces of the model.
    """

    # The base velocity is forwarded only for floating-base models,
    # otherwise the builder receives no base velocity at all.
    if model.floating_base():
        v_WB = data.base_velocity
        linear_part, angular_part = v_WB[:3], v_WB[3:]
    else:
        linear_part, angular_part = None, None

    # Copy of the state holding the generalized position and velocity.
    state_copy = js.data.JaxSimModelData.build(
        model=model,
        velocity_representation=data.velocity_representation,
        base_position=data.base_position,
        base_quaternion=data.base_quaternion,
        joint_positions=data.joint_positions,
        joint_velocities=data.joint_velocities,
        base_linear_velocity=linear_part,
        base_angular_velocity=angular_part,
    )

    # Zero inputs of the inverse dynamics.
    zero_joint_accelerations = jnp.atleast_1d(jnp.zeros(model.dofs()))
    zero_base_acceleration = jnp.zeros(6)
    zero_link_forces = jnp.zeros(shape=(model.number_of_links(), 6))

    base_force, joint_forces = inverse_dynamics(
        model=model,
        data=state_copy,
        joint_accelerations=zero_joint_accelerations,
        base_acceleration=zero_base_acceleration,
        link_forces=zero_link_forces,
    )

    # Stack the base and joint components in a single generalized force.
    generalized_force = jnp.hstack((base_force, joint_forces))

    return generalized_force.astype(float)
# ==========================
# Other kinematic quantities
# ==========================
@jax.jit
@js.common.named_scope
def locked_spatial_inertia(
    model: JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    """
    Compute the locked 6D inertia matrix of the model.

    It corresponds to the first six columns of the total momentum Jacobian.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The locked 6D inertia matrix of the model.
    """

    momentum_jacobian = total_momentum_jacobian(model=model, data=data)

    # Keep only the columns associated with the base velocity.
    return momentum_jacobian[:, 0:6]
@jax.jit
@js.common.named_scope
def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
    """
    Compute the total momentum of the model.

    The momentum is the product between the total momentum Jacobian and
    the generalized velocity.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The total momentum of the model in the active velocity representation.
    """

    generalized_velocity = data.generalized_velocity
    momentum_jacobian = total_momentum_jacobian(model=model, data=data)

    return momentum_jacobian @ generalized_velocity
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def total_momentum_jacobian(
model: JaxSimModel,
data: js.data.JaxSimModelData,
*,
output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
"""
Compute the jacobian of the total momentum.
Args:
model: The model to consider.
data: The data of the considered model.
output_vel_repr: The output velocity representation of the jacobian.
Returns:
The jacobian of the total momentum of the model in the active representation.
"""
output_vel_repr = (
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)
if output_vel_repr is data.velocity_representation:
return free_floating_mass_matrix(model=model, data=data)[0:6]
with data.switch_velocity_representation(VelRepr.Body):
B_Jh_B = free_floating_mass_matrix(model=model, data=data)[0:6]
match data.velocity_representation:
case VelRepr.Body:
B_Jh = B_Jh_B
case VelRepr.Inertial:
B_X_W = Adjoint.from_transform(transform=data._base_transform, inverse=True)
B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))
case VelRepr.Mixed:
BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
case _:
raise ValueError(data.velocity_representation)
match output_vel_repr:
case VelRepr.Body:
return B_Jh
case VelRepr.Inertial:
W_H_B = data._base_transform
B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True)
W_Xf_B = B_Xv_W.T
W_Jh = W_Xf_B @ B_Jh
return W_Jh
case VelRepr.Mixed:
BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))
B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
BW_Xf_B = B_Xv_BW.T
BW_Jh = BW_Xf_B @ B_Jh
return BW_Jh
case _:
raise ValueError(output_vel_repr)
@jax.jit
@js.common.named_scope
def average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:
    """
    Compute the average velocity of the model.

    The velocity is the product between the average velocity Jacobian and
    the generalized velocity.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The average velocity of the model computed in the base frame and expressed
        in the active representation.
    """

    generalized_velocity = data.generalized_velocity
    jacobian = average_velocity_jacobian(model=model, data=data)

    return jacobian @ generalized_velocity
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def average_velocity_jacobian(
model: JaxSimModel,
data: js.data.JaxSimModelData,
*,
output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
"""
Compute the Jacobian of the average velocity of the model.
Args:
model: The model to consider.
data: The data of the considered model.
output_vel_repr: The output velocity representation of the jacobian.
Returns:
The Jacobian of the average centroidal velocity of the model in the desired
representation.
"""
output_vel_repr = (
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)
# Depending on the velocity representation, the frame G is either G[W] or G[B].
G_J = js.com.average_centroidal_velocity_jacobian(model=model, data=data)
match output_vel_repr:
case VelRepr.Inertial:
GW_J = G_J
W_p_CoM = js.com.com_position(model=model, data=data)
W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
W_X_GW = Adjoint.from_transform(transform=W_H_GW)
return W_X_GW @ GW_J
case VelRepr.Body:
GB_J = G_J
W_p_B = data.base_position
W_p_CoM = js.com.com_position(model=model, data=data)
B_R_W = jaxsim.math.Quaternion.to_dcm(data.base_orientation).transpose()
B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B))
B_X_GB = Adjoint.from_transform(transform=B_H_GB)
return B_X_GB @ GB_J
case VelRepr.Mixed:
GW_J = G_J
W_p_B = data.base_position
W_p_CoM = js.com.com_position(model=model, data=data)
BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B)
BW_X_GW = Adjoint.from_transform(transform=BW_H_GW)
return BW_X_GW @ GW_J
# ========================
# Other dynamic quantities
# ========================
@jax.jit
@js.common.named_scope
def link_bias_accelerations(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
) -> jtp.Vector:
    r"""
    Compute the bias accelerations of the links of the model.

    A zero base acceleration in the active representation is converted to
    inertial-fixed, the velocities and accelerations are propagated through
    the kinematic tree in body-fixed representation with zero joint
    accelerations, and the result is finally converted to the representation
    active in `data`.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The bias accelerations of the links of the model, stacked along the
        first axis (one 6D acceleration per link).

    Raises:
        ValueError: If the velocity representation of `data` is not supported.

    Note:
        This function computes the component of the total 6D acceleration not due to
        the joint or base acceleration.
        It is often called :math:`\dot{J} \boldsymbol{\nu}`.
    """

    # ================================================
    # Compute the body-fixed zero base 6D acceleration
    # ================================================

    # Compute the base transform.
    W_H_B = data._base_transform

    def other_representation_to_inertial(
        C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector
    ) -> jtp.Vector:
        """
        Convert the active representation of the base acceleration C_v̇_WB
        expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.
        """

        W_X_C = Adjoint.from_transform(transform=W_H_C)
        C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)

        # In Mixed representation, we need to include a cross product in ℝ⁶.
        # In Inertial and Body representations, the cross product is always zero.
        return W_X_C @ (C_v̇_WB + jaxsim.math.Cross.vx(C_X_W @ W_v_WC) @ C_v_WB)

    # Here we initialize a zero 6D acceleration in the active representation, and
    # convert it to inertial-fixed. This is a useful intermediate representation
    # because the apparent acceleration W_v̇_WB is equal to the intrinsic acceleration
    # W_a_WB, and intrinsic accelerations can be expressed in different frames through
    # a simple C_X_W 6D transform.
    match data.velocity_representation:
        case VelRepr.Inertial:
            # In this case C=W
            W_H_C = W_H_W = jnp.eye(4)  # noqa: F841
            W_v_WC = W_v_WW = jnp.zeros(6)  # noqa: F841
            with data.switch_velocity_representation(VelRepr.Inertial):
                C_v_WB = W_v_WB = data.base_velocity

        case VelRepr.Body:
            # In this case C=B
            W_H_C = W_H_B
            with data.switch_velocity_representation(VelRepr.Inertial):
                W_v_WC = W_v_WB = data.base_velocity  # noqa: F841
            with data.switch_velocity_representation(VelRepr.Body):
                C_v_WB = B_v_WB = data.base_velocity

        case VelRepr.Mixed:
            # In this case C=B[W]
            W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
            W_H_C = W_H_BW
            with data.switch_velocity_representation(VelRepr.Mixed):
                W_ṗ_B = data.base_velocity[0:3]
                BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)
                W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW)
                W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW  # noqa: F841
            with data.switch_velocity_representation(VelRepr.Mixed):
                C_v_WB = BW_v_WB = data.base_velocity  # noqa: F841

        case _:
            raise ValueError(data.velocity_representation)

    # Convert a zero 6D acceleration from the active representation to inertial-fixed.
    W_v̇_WB = other_representation_to_inertial(
        C_v̇_WB=jnp.zeros(6), C_v_WB=C_v_WB, W_H_C=W_H_C, W_v_WC=W_v_WC
    )

    # ===================================
    # Initialize buffers and prepare data
    # ===================================

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute 6D transforms of the base velocity.
    B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)

    # Compute the parent-to-child adjoints and the motion subspaces of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    # Ensure cached transforms stay on device when indexed with traced `i`.
    i_X_λi = jnp.asarray(data._joint_transforms)

    # Extract the joint motion subspaces.
    S = model.kin_dyn_parameters.motion_subspaces

    # Allocate the buffer to store the body-fixed link velocities.
    L_v_WL = jnp.zeros(shape=(model.number_of_links(), 6))

    # Store the base velocity.
    with data.switch_velocity_representation(VelRepr.Body):
        B_v_WB = data.base_velocity
        L_v_WL = L_v_WL.at[0].set(B_v_WB)

    # Get the joint velocities.
    ṡ = data.joint_velocities

    # Allocate the buffer to store the body-fixed link accelerations,
    # and initialize the base acceleration.
    L_v̇_WL = jnp.zeros(shape=(model.number_of_links(), 6))
    L_v̇_WL = L_v̇_WL.at[0].set(B_X_W @ W_v̇_WB)

    # ======================================
    # Propagate accelerations and velocities
    # ======================================

    # The computation of the bias accelerations is similar to the forward pass of
    # RNEA, this time with zero base and joint accelerations. Furthermore, here we
    # do not remove gravity during the propagation.

    # Initialize the loop.
    # The carry holds the buffers of link velocities and link accelerations.
    Carry = tuple[jtp.Matrix, jtp.Matrix]
    carry0: Carry = (L_v_WL, L_v̇_WL)

    def propagate_accelerations(carry: Carry, i: jtp.Int) -> tuple[Carry, None]:
        # Initialize index and unpack the carry.
        # The link index `i` starts from 1, the joint index `ii` from 0.
        ii = i - 1
        v, a = carry

        # Get the motion subspace of the joint.
        Si = S[i].squeeze()

        # Project the joint velocity into its motion subspace.
        vJ = Si * ṡ[ii]

        # Propagate the link body-fixed velocity.
        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        # Propagate the link body-fixed acceleration considering zero joint acceleration.
        s̈ = 0.0
        a_i = i_X_λi[i] @ a[λ[i]] + Si * s̈ + jaxsim.math.Cross.vx(v[i]) @ vJ
        a = a.at[i].set(a_i)

        return (v, a), None

    # Compute the body-fixed velocity and body-fixed apparent acceleration of the links.
    # The scan is skipped for single-link models, since there is nothing to propagate.
    (L_v_WL, L_v̇_WL), _ = (
        jax.lax.scan(
            f=propagate_accelerations,
            init=carry0,
            xs=jnp.arange(start=1, stop=model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(L_v_WL, L_v̇_WL), None]
    )

    # ===================================================================
    # Convert the body-fixed 6D acceleration to the active representation
    # ===================================================================

    def body_to_other_representation(
        L_v̇_WL: jtp.Vector, L_v_WL: jtp.Vector, C_H_L: jtp.Matrix, L_v_CL: jtp.Vector
    ) -> jtp.Vector:
        """
        Convert the body-fixed apparent acceleration L_v̇_WL to
        another representation C_v̇_WL expressed in a generic frame C.
        """

        # In Mixed representation, we need to include a cross product in ℝ⁶.
        # In Inertial and Body representations, the cross product is always zero.
        C_X_L = jaxsim.math.Adjoint.from_transform(transform=C_H_L)
        return C_X_L @ (L_v̇_WL + jaxsim.math.Cross.vx(L_v_CL) @ L_v_WL)

    # Select, for each link, the transform C_H_L and the velocity L_v_CL
    # associated with the active representation.
    match data.velocity_representation:
        case VelRepr.Body:
            # In this case C=L
            C_H_L = L_H_L = jnp.stack(  # noqa: F841
                [jnp.eye(4)] * model.number_of_links()
            )
            L_v_CL = L_v_LL = jnp.zeros(  # noqa: F841
                shape=(model.number_of_links(), 6)
            )

        case VelRepr.Inertial:
            # In this case C=W
            C_H_L = W_H_L = data._link_transforms
            L_v_CL = L_v_WL

        case VelRepr.Mixed:
            # In this case C=L[W]
            W_H_L = data._link_transforms
            # Link transforms with zeroed translation.
            LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L)
            C_H_L = LW_H_L
            # Link velocities with zeroed first three components.
            L_v_CL = L_v_LW_L = jax.vmap(  # noqa: F841
                lambda v: v.at[0:3].set(jnp.zeros(3))
            )(L_v_WL)

        case _:
            raise ValueError(data.velocity_representation)

    # Convert from body-fixed to the active representation.
    O_v̇_WL = jax.vmap(body_to_other_representation)(
        L_v̇_WL=L_v̇_WL, L_v_WL=L_v_WL, C_H_L=C_H_L, L_v_CL=L_v_CL
    )

    return O_v̇_WL
@jax.jit
def joint_transforms(
    model: JaxSimModel, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
) -> jtp.Array:
    r"""
    Compute the transforms associated with the joints of the model.

    This is a thin wrapper that forwards the request to the kinematic and
    dynamic parameters stored in the model.

    Args:
        model: The model to consider.
        joint_positions: The joint positions.
        base_transform: The homogeneous matrix defining the base pose.

    Returns:
        The stacked transforms
        :math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)`
        of each joint.
    """

    # The actual computation lives in the KinDynParameters object.
    kin_dyn = model.kin_dyn_parameters

    return kin_dyn.joint_transforms(
        joint_positions=joint_positions, base_transform=base_transform
    )
# ======
# Energy
# ======
@jax.jit
@js.common.named_scope
def mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
    """
    Compute the mechanical energy of the model.

    The mechanical energy is the sum of the kinetic and the potential energy.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The mechanical energy of the model.
    """

    # Kinetic term first, potential term second.
    total = kinetic_energy(model=model, data=data) + potential_energy(
        model=model, data=data
    )

    return total.astype(float)
@jax.jit
@js.common.named_scope
def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
    """
    Compute the kinetic energy of the model.

    The energy is evaluated as the quadratic form of the free-floating mass
    matrix with the generalized velocity, both read while the data is switched
    to the body-fixed velocity representation.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The kinetic energy of the model.
    """

    body_fixed = VelRepr.Body

    # Read velocity and mass matrix in the same (body-fixed) representation.
    with data.switch_velocity_representation(velocity_representation=body_fixed):
        velocity = data.generalized_velocity
        mass_matrix = free_floating_mass_matrix(model=model, data=data)

    energy = 0.5 * velocity.T @ mass_matrix @ velocity

    # Collapse the result to a scalar with the default float type.
    return energy.squeeze().astype(float)
@jax.jit
@js.common.named_scope
def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:
    """
    Compute the potential energy of the model.

    The result is the product of the total mass, the third (z) component of
    the center of mass position, and the gravity stored in the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The potential energy of the model.
    """

    mass = total_mass(model=model)

    # Homogeneous coordinates of the center of mass.
    com_position = js.com.com_position(model=model, data=data)
    homogeneous_com = jnp.hstack([com_position, 1])

    # Only the z component of the mass-weighted CoM contributes.
    weighted_height = (mass * homogeneous_com)[2]

    return jnp.sum(weighted_height * model.gravity)
# ===================
# Hw parametrization
# ===================
@jax.jit
@js.common.named_scope
def update_hw_parameters(
    model: JaxSimModel, scaling_factors: ScalingFactors
) -> JaxSimModel:
    """
    Update the hardware parameters of the model by scaling the parameters of the links.

    This function applies scaling factors to the hardware metadata of the links,
    updating their shape, dimensions, density, and other related parameters. It
    recalculates the mass and inertia tensors of the links based on the updated
    metadata and adjusts the joint model transforms accordingly.

    The update proceeds in the following stages:

    1. Per-link scaling of geometry, density, and of the translations of the
       CoM (`L_H_G`) and visual (`L_H_vis`) transforms.
    2. Scaling of the joint pre-transforms (`L_H_pre`) of each link, skipped
       when the model has no joints.
    3. Recomputation of mass, inertia, and center of mass of the links.
    4. Recomputation of the contact points, only if the model has any.
    5. Update of the `λ_H_pre` transforms of the joint model.

    Args:
        model: The JaxSimModel object to update.
        scaling_factors: A ScalingFactors object containing scaling factors for
            dimensions and density of the links.

    Returns:
        The updated JaxSimModel object with modified hardware parameters.

    Note:
        Links whose shape is `LinkParametrizableShape.Unsupported` keep their
        geometry, density and transforms, and use a unit scale vector.
    """

    kin_dyn_params: KinDynParameters = model.kin_dyn_parameters
    link_parameters: LinkParameters = kin_dyn_params.link_parameters
    hw_link_metadata: HwLinkMetadata = kin_dyn_params.hw_link_metadata

    # Python-level flag used to select code paths depending on the
    # presence of joints in the model.
    has_joints = model.number_of_joints() > 0

    def apply_scaling_single_link(
        link_shape,
        geometry,
        density,
        L_H_G,
        L_H_vis,
        L_H_pre,
        L_H_pre_mask,
        scaling_dims,
        scaling_density,
    ):
        """Apply scaling to a single link's numerical data."""

        # NOTE: L_H_pre and L_H_pre_mask are part of the signature but are not
        # used in this helper. The joint transforms are handled separately in
        # transform_all_joints.

        def scale_supported(_):
            # Select, for each of the three axes, which entry of the scaling
            # dimensions applies. The row is picked by the link shape.
            # NOTE(review): presumably the rows follow the integer values of
            # LinkParametrizableShape -- confirm against the enum definition.
            shape_indices_map = jnp.array([[0, 1, 2], [0, 0, 1], [0, 0, 0], [0, 1, 2]])
            per_link_indices = shape_indices_map[link_shape]
            scale_vector = scaling_dims[per_link_indices]

            # Update kinematics
            # The visual transform is first expressed w.r.t. the G frame, then
            # the translations of both G_H_vis and L_H_G are scaled and
            # composed back to obtain the scaled L_H_vis.
            G_H_L = jaxsim.math.Transform.inverse(L_H_G)
            G_H_vis = G_H_L @ L_H_vis
            G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3])
            L_H̅_G = L_H_G.at[:3, 3].set(scale_vector * L_H_G[:3, 3])
            L_H̅_vis = L_H̅_G @ G_H̅_vis

            # Update shape parameters
            updated_geom = geometry * scaling_dims
            updated_dens = density * scaling_density

            return updated_geom, updated_dens, L_H̅_G, L_H̅_vis, scale_vector

        def scale_unsupported(_):
            # Leave every quantity untouched and report a unit scale vector.
            return (
                geometry,
                density,
                L_H_G,
                L_H_vis,
                jnp.ones_like(scaling_dims),
            )

        # Both branches return a tuple with the same structure.
        return jax.lax.cond(
            link_shape == LinkParametrizableShape.Unsupported,
            scale_unsupported,
            scale_supported,
            operand=None,
        )

    # Vmap over all links for basic scaling
    # NOTE: every argument is mapped over its leading axis, hence the scaling
    # factors are assumed to be provided per link -- TODO confirm.
    (
        updated_geometry,
        updated_density,
        updated_L_H_G,
        updated_L_H_vis,
        scale_vectors,
    ) = jax.vmap(apply_scaling_single_link)(
        hw_link_metadata.link_shape,
        hw_link_metadata.geometry,
        hw_link_metadata.density,
        hw_link_metadata.L_H_G,
        hw_link_metadata.L_H_vis,
        hw_link_metadata.L_H_pre,
        hw_link_metadata.L_H_pre_mask,
        scaling_factors.dims,
        scaling_factors.density,
    )

    # Handle joint transforms separately, only if model has joints
    def transform_all_joints(operands):
        """Transform all joint poses across all links."""

        original_L_H_G, updated_L_H_G, scale_vectors, L_H_pre, L_H_pre_mask = operands

        # Vectorized transformation: (n_links, n_joints, 4, 4)
        # Express joint transforms in the original CoM frames.
        # Using the already-scaled L_H_G here introduces a second implicit
        # scaling term and distorts kinematic chain proportions.
        G_H_L_all = jax.vmap(jaxsim.math.Transform.inverse)(
            original_L_H_G
        )  # (n_links, 4, 4)

        # Use batch matrix multiply with broadcasting
        # G_H_L_all: (n_links, 4, 4) -> (n_links, 1, 4, 4)
        # L_H_pre: (n_links, n_joints, 4, 4)
        # Result: (n_links, n_joints, 4, 4)
        G_H_pre = G_H_L_all[:, None, :, :] @ L_H_pre

        # Scale translation components
        # Only the entries enabled in the mask are scaled, the others keep
        # their original translation.
        G_H̅_pre = G_H_pre.at[:, :, :3, 3].set(
            jnp.where(
                L_H_pre_mask[:, :, None],
                scale_vectors[:, None, :] * G_H_pre[:, :, :3, 3],
                G_H_pre[:, :, :3, 3],
            )
        )

        # Transform back to link frames
        # updated_L_H_G: (n_links, 4, 4) -> (n_links, 1, 4, 4)
        # G_H̅_pre: (n_links, n_joints, 4, 4)
        # Result: (n_links, n_joints, 4, 4)
        return updated_L_H_G[:, None, :, :] @ G_H̅_pre

    # Without joints, the pre-transforms stored in the metadata are returned
    # as they are (fourth element of the operand tuple).
    updated_L_H_pre = jax.lax.cond(
        has_joints,
        transform_all_joints,
        lambda operands: operands[3],  # Return L_H_pre unchanged
        operand=(
            hw_link_metadata.L_H_G,
            updated_L_H_G,
            scale_vectors,
            hw_link_metadata.L_H_pre,
            hw_link_metadata.L_H_pre_mask,
        ),
    )

    # Create updated HwLinkMetadata
    updated_hw_link_metadata = hw_link_metadata.replace(
        geometry=updated_geometry,
        density=updated_density,
        L_H_G=updated_L_H_G,
        L_H_vis=updated_L_H_vis,
        L_H_pre=updated_L_H_pre,
    )

    # Compute mass and inertia once and unpack the results
    m_updated, I_com_updated = HwLinkMetadata.compute_mass_and_inertia(
        updated_hw_link_metadata
    )

    # Rotate the inertia tensor at CoM with the link orientation, and store
    # it in KinDynParameters.
    I_L_updated = jax.vmap(
        lambda metadata, I_com: metadata.L_H_G[:3, :3]
        @ I_com
        @ metadata.L_H_G[:3, :3].T
    )(updated_hw_link_metadata, I_com_updated)

    # Update link parameters
    # The center of mass is the translation of the updated L_H_G transform.
    updated_link_parameters = link_parameters.replace(
        mass=m_updated,
        inertia_elements=jax.vmap(LinkParameters.flatten_inertia_tensor)(I_L_updated),
        center_of_mass=jax.vmap(lambda metadata: metadata.L_H_G[:3, 3])(
            updated_hw_link_metadata
        ),
    )

    # The contact points are recomputed only if the model defines any.
    if kin_dyn_params.contact_parameters.body:
        # Compute the contact parameters
        points = HwLinkMetadata.compute_contact_points(
            original_contact_params=kin_dyn_params.contact_parameters,
            link_shapes=updated_hw_link_metadata.link_shape,
            original_com_positions=link_parameters.center_of_mass,
            updated_com_positions=updated_link_parameters.center_of_mass,
            scaling_factors=scaling_factors,
        )

        # Update contact parameters
        updated_contact_parameters = kin_dyn_params.contact_parameters.replace(
            point=points
        )
    else:
        updated_contact_parameters = kin_dyn_params.contact_parameters

    # Update joint model transforms (λ_H_pre)
    def update_λ_H_pre(joint_index):
        # Extract the transforms and masks for the current joint index across all links
        L_H_pre_for_joint = updated_hw_link_metadata.L_H_pre[:, joint_index]
        L_H_pre_mask_for_joint = updated_hw_link_metadata.L_H_pre_mask[:, joint_index]

        # Select the first valid transform (if any) using the mask
        # When no entry of the mask is set, the selected index is not
        # meaningful and the fallback below is used instead.
        first_valid_index = jnp.argmax(L_H_pre_mask_for_joint)
        selected_transform = L_H_pre_for_joint[first_valid_index]

        # Check if any valid transform exists
        has_valid_transform = L_H_pre_mask_for_joint.any()

        # Fallback to the original λ_H_pre if no valid transform exists
        # The +1 offset skips the 0-th element, reserved to the base.
        fallback_transform = kin_dyn_params.joint_model.λ_H_pre[joint_index + 1]

        # Return the selected transform or fallback
        return jnp.where(has_valid_transform, selected_transform, fallback_transform)

    if has_joints:
        # Apply the update function to all joint indices
        updated_λ_H_pre = jax.vmap(update_λ_H_pre)(
            jnp.arange(kin_dyn_params.number_of_joints())
        )

        # NOTE: λ_H_pre should be of len (1+n_joints) with the 0-th element equal
        # to identity to represent the world-to-base tree transform. See JointModel class
        updated_λ_H_pre_with_base = jnp.concatenate(
            (jnp.eye(4).reshape(1, 4, 4), updated_λ_H_pre), axis=0
        )

        # Replace the joint model with the updated transforms
        updated_joint_model = kin_dyn_params.joint_model.replace(
            λ_H_pre=updated_λ_H_pre_with_base
        )
    else:
        # If there are no joints, there is nothing to update and the original
        # joint model is kept as it is.
        updated_joint_model = kin_dyn_params.joint_model

    # Replace the kin_dyn_parameters with updated values
    updated_kin_dyn_params = kin_dyn_params.replace(
        link_parameters=updated_link_parameters,
        contact_parameters=updated_contact_parameters,
        hw_link_metadata=updated_hw_link_metadata,
        joint_model=updated_joint_model,
    )

    # Return the updated model
    return model.replace(kin_dyn_parameters=updated_kin_dyn_params)
# ==========
# Simulation
# ==========
@jax.jit
@js.common.named_scope
def step(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_forces: jtp.MatrixLike | None = None,
    joint_force_references: jtp.VectorLike | None = None,
) -> js.data.JaxSimModelData:
    """
    Perform a simulation step.

    The integrator is selected from the model, and the contact model of the
    model is given the chance to update the velocity after an impact.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_forces:
            The 6D forces to apply to the links expressed in same representation of data.
            If not specified, zero forces are applied to all links.
        joint_force_references:
            The joint force references to consider. If not specified, zero
            references are used for all degrees of freedom.

    Returns:
        The new data of the model after the simulation step.

    Note:
        In order to reduce the occurrences of frame conversions performed internally,
        it is recommended to use inertial-fixed velocity representation. This can be
        particularly useful for automatically differentiated logic.
    """

    # TODO: some contact models here may want to perform a dynamic filtering of
    # the enabled collidable points

    # ------------------------------------------------
    # Link forces: default to zero, then make them 2D.
    # ------------------------------------------------

    if link_forces is None:
        raw_link_forces = jnp.zeros((model.number_of_links(), 6))
    else:
        raw_link_forces = jnp.array(link_forces, dtype=float).squeeze()

    active_link_forces = jnp.atleast_2d(raw_link_forces)

    # The integrators consume the forces in inertial-fixed representation.
    inertial_link_forces = js.data.JaxSimModelData.other_representation_to_inertial(
        active_link_forces,
        other_representation=data.velocity_representation,
        transform=data._link_transforms,
        is_force=True,
    )

    # ----------------------------------------------------
    # Joint references: default to zero, then make them 1D.
    # ----------------------------------------------------

    if joint_force_references is None:
        raw_references = jnp.zeros(model.dofs())
    else:
        raw_references = jnp.array(joint_force_references, dtype=float).squeeze()

    references = jnp.atleast_1d(raw_references)

    # Resultant joint torques computed by the actuation model.
    total_torques = js.actuation_model.compute_resultant_torques(
        model, data, joint_force_references=references
    )

    # ---------------------------
    # Integrate and handle impacts
    # ---------------------------

    # Local import, kept inside the function as in the rest of the module.
    from .integrators import _INTEGRATORS_MAP

    advance = _INTEGRATORS_MAP[model.integrator]

    integrated_data = advance(
        model=model,
        data=data,
        link_forces=inertial_link_forces,
        joint_torques=total_torques,
    )

    return model.contact_model.update_velocity_after_impact(
        model=model, data=integrated_data
    )
================================================
FILE: src/jaxsim/api/ode.py
================================================
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Quaternion, Skew
from jaxsim.rbda.kinematic_constraints import compute_constraint_wrenches
from .common import VelRepr
# ==================================
# Functions defining system dynamics
# ==================================
def system_acceleration(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_forces: jtp.MatrixLike | None = None,
    joint_torques: jtp.VectorLike | None = None,
) -> tuple[jtp.Vector, jtp.Vector, dict[str, jtp.PyTree]]:
    """
    Compute the system acceleration in the active representation.

    The link forces given as input are summed with the forces due to the
    contacts with the terrain (if the model has collidable points) and with
    the wrenches enforcing the kinematic constraints, and the resulting forces
    are passed to the forward dynamics computed with ABA.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_forces:
            The 6D forces to apply to the links expressed in the same
            velocity representation of data. Any array-like object is
            accepted. If not specified, zero forces are applied.
        joint_torques: The joint torques applied to the joints.

    Returns:
        A tuple containing the base 6D acceleration in the active representation,
        the joint accelerations, and the contact state.
    """

    # ====================
    # Validate input data
    # ====================

    # Build link forces if not provided.
    # The input is converted to a JAX array before squeezing so that also
    # array-like objects not exposing `squeeze` (e.g. lists) are supported,
    # consistently with the other entry points accepting link forces.
    f_L = (
        jnp.atleast_2d(jnp.array(link_forces, dtype=float).squeeze())
        if link_forces is not None
        else jnp.zeros((model.number_of_links(), 6))
    ).astype(float)

    # ======================
    # Compute contact forces
    # ======================

    W_f_L_terrain = jnp.zeros_like(f_L)
    contact_state = data.contact_state

    if len(model.kin_dyn_parameters.contact_parameters.body) > 0:

        # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
        # with the terrain.
        W_f_L_terrain, contact_state_derivative = js.contact.link_contact_forces(
            model=model,
            data=data,
            link_forces=f_L,
            joint_torques=joint_torques,
        )

        # Update the contact state data. This is necessary only for the contact models
        # that require propagation and integration of contact state.
        contact_state = model.contact_model.update_contact_state(
            contact_state_derivative
        )

    # ===================================
    # Compute kinematic constraint forces
    # ===================================

    # Sum up all the forces: external + contact
    W_f_L_total = f_L + W_f_L_terrain

    # Compute the 6D forces W_f ∈ ℝ^{n_constraints × 2 × 6} applied to links due to
    # kinematic constraints.
    W_f_L_constraints = compute_constraint_wrenches(
        model=model,
        data=data,
        link_forces_inertial=W_f_L_total,
        joint_force_references=joint_torques,
    )

    # Apply constraint forces to the corresponding links
    if W_f_L_constraints.shape[0] > 0:

        # Get the constraint map from the model's kinematic parameters
        constraint_map = model.kin_dyn_parameters.constraints

        if constraint_map is not None:

            # Stack the parent link indices for both sides of each constraint
            parent_indices_flat = jnp.concatenate(
                [constraint_map.parent_link_idxs_1, constraint_map.parent_link_idxs_2],
            )

            # Flatten the constraint wrenches to match the flattened parent indices
            constraint_wrenches_flat = W_f_L_constraints.reshape(-1, 6)

            # Apply constraint wrenches using scatter_add for better performance
            W_f_L_total = W_f_L_total.at[parent_indices_flat].add(
                constraint_wrenches_flat
            )

    # Store the link forces in a references object.
    references = js.references.JaxSimModelReferences.build(
        model=model,
        data=data,
        velocity_representation=data.velocity_representation,
        link_forces=W_f_L_total,
    )

    # Compute forward dynamics.
    #
    # - Joint accelerations: s̈ ∈ ℝⁿ
    # - Base acceleration: v̇_WB ∈ ℝ⁶
    #
    # Note that ABA returns the base acceleration in the velocity representation
    # stored in the `data` object.
    v̇_WB, s̈ = js.model.forward_dynamics_aba(
        model=model,
        data=data,
        joint_forces=joint_torques,
        link_forces=references.link_forces(model=model, data=data),
    )

    return v̇_WB, s̈, contact_state
@jax.jit
@js.common.named_scope
def system_position_dynamics(
    data: js.data.JaxSimModelData,
    baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
    r"""
    Compute the dynamics of the system position.

    Args:
        data: The data of the considered model.
        baumgarte_quaternion_regularization:
            The Baumgarte regularization coefficient for adjusting the quaternion norm.

    Returns:
        A tuple containing the derivative of the base position, the derivative of the
        base quaternion, and the derivative of the joint positions.

    Note:
        In inertial-fixed representation, the linear component of the base velocity is not
        the derivative of the base position. In fact, the base velocity is defined as:
        :math:`{}^W v_{W, B} = \begin{bmatrix} {}^W \dot{p}_B - S({}^W \omega_{W, B}) {}^W p_B \\ {}^W \omega_{W, B} \end{bmatrix}`.
        Where :math:`S(\cdot)` is the skew-symmetric matrix operator.
    """

    base_velocity = data.base_velocity
    linear_velocity = base_velocity[0:3]
    angular_velocity = base_velocity[3:6]

    # Recover the derivative of the base position from the linear part of
    # the base velocity by adding back the S(ω) p term.
    base_position_dot = (
        linear_velocity + Skew.wedge(angular_velocity) @ data.base_position
    )

    # The angular velocity is passed as not expressed in the body frame.
    base_quaternion_dot = Quaternion.derivative(
        quaternion=data.base_orientation,
        omega=angular_velocity,
        omega_in_body_fixed=False,
        K=baumgarte_quaternion_regularization,
    ).squeeze()

    # The derivative of the joint positions is the joint velocity itself.
    joint_positions_dot = data.joint_velocities

    return base_position_dot, base_quaternion_dot, joint_positions_dot
@jax.jit
@js.common.named_scope
def system_dynamics(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_forces: jtp.Vector | None = None,
    joint_torques: jtp.Vector | None = None,
    baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,
) -> dict[str, jtp.Vector]:
    """
    Compute the dynamics of the system.

    Both the accelerations and the position derivatives are computed while the
    data is switched to the inertial-fixed velocity representation.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_forces:
            The 6D forces to apply to the links expressed in the frame corresponding to
            the velocity representation of `data`.
        joint_torques: The joint torques acting on the joints.
        baumgarte_quaternion_regularization:
            The Baumgarte regularization coefficient used to adjust the norm of the
            quaternion (only used in integrators not operating on the SO(3) manifold).

    Returns:
        A dictionary containing the derivatives of the base position, the base quaternion,
        the joint positions, the base linear velocity, the base angular velocity, and the
        joint velocities. The dictionary also stores the contact state returned
        by the computation of the system acceleration.
    """

    inertial_fixed = VelRepr.Inertial

    with data.switch_velocity_representation(velocity_representation=inertial_fixed):

        # Accelerations of the base and of the joints.
        base_acceleration, joint_accelerations, contact_state_derivative = (
            system_acceleration(
                model=model,
                data=data,
                joint_torques=joint_torques,
                link_forces=link_forces,
            )
        )

        # Derivatives of the position-level quantities.
        base_position_dot, base_quaternion_dot, joint_positions_dot = (
            system_position_dynamics(
                data=data,
                baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,
            )
        )

    # Each key stores the derivative of the quantity it is named after.
    return {
        "base_position": base_position_dot,
        "base_quaternion": base_quaternion_dot,
        "joint_positions": joint_positions_dot,
        "base_linear_velocity": base_acceleration[0:3],
        "base_angular_velocity": base_acceleration[3:6],
        "joint_velocities": joint_accelerations,
        "contact_state": contact_state_derivative,
    }
================================================
FILE: src/jaxsim/api/references.py
================================================
from __future__ import annotations
import functools
import jax
import jax.numpy as jnp
import jax_dataclasses
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import exceptions
from jaxsim.utils.tracing import not_tracing
from .common import VelRepr
try:
from typing import Self
except ImportError:
from typing_extensions import Self
@jax_dataclasses.pytree_dataclass
class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
"""
Class containing the references for a `JaxSimModel` object.
Attributes:
_link_forces: The link 6D forces in inertial-fixed representation.
_joint_force_references: The joint force references.
"""
_link_forces: jtp.Matrix
_joint_force_references: jtp.Vector
@staticmethod
def zero(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData | None = None,
    velocity_representation: VelRepr = VelRepr.Inertial,
) -> JaxSimModelReferences:
    """
    Create a `JaxSimModelReferences` object with zero references.

    Args:
        model: The model for which to create the zero references.
        data:
            The data of the model, only needed if the velocity representation is
            not inertial-fixed.
        velocity_representation: The velocity representation to use.

    Returns:
        A `JaxSimModelReferences` object with zero state.
    """

    # Not passing any reference to the builder yields zero references.
    zero_references = JaxSimModelReferences.build(
        model=model,
        data=data,
        velocity_representation=velocity_representation,
    )

    return zero_references
@staticmethod
def build(
    model: js.model.JaxSimModel,
    joint_force_references: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    data: js.data.JaxSimModelData | None = None,
    velocity_representation: VelRepr | None = None,
) -> JaxSimModelReferences:
    """
    Create a `JaxSimModelReferences` object with the given references.

    Args:
        model: The model for which to create the state.
        joint_force_references:
            The joint force references. Zero references are used if not specified.
        link_forces:
            The link 6D forces in the desired representation. Zero forces are
            used if not specified.
        data:
            The data of the model, only needed if the velocity representation is
            not inertial-fixed.
        velocity_representation:
            The velocity representation to use. If not specified, it is read from
            `data` when available, otherwise the inertial-fixed one is used.

    Returns:
        A `JaxSimModelReferences` object with the given references.
    """

    # Joint force references as a 1D float array.
    if joint_force_references is None:
        raw_joint_forces = jnp.zeros(model.dofs())
    else:
        raw_joint_forces = jnp.array(joint_force_references, dtype=float).squeeze()

    joint_forces = jnp.atleast_1d(raw_joint_forces).astype(float)

    # Link forces as a 2D float array.
    if link_forces is None:
        raw_link_forces = jnp.zeros((model.number_of_links(), 6))
    else:
        raw_link_forces = jnp.array(link_forces, dtype=float).squeeze()

    f_L = jnp.atleast_2d(raw_link_forces).astype(float)

    # Resolve the velocity representation.
    if velocity_representation is None:
        velocity_representation = getattr(
            data, "velocity_representation", VelRepr.Inertial
        )

    # Initial references object storing the inputs as they are.
    references = JaxSimModelReferences(
        _link_forces=f_L,
        _joint_force_references=joint_forces,
        velocity_representation=velocity_representation,
    )

    # Link forces are stored in inertial-fixed representation, therefore in
    # this case the object is already final.
    if velocity_representation is VelRepr.Inertial:
        return references

    # Otherwise, go through the setters so that the link forces get converted
    # from the active representation before being stored.
    return references.set_joint_force_references(
        forces=joint_forces,
        model=model,
        joint_names=model.joint_names(),
    ).apply_link_forces(
        forces=f_L,
        model=model,
        data=data,
        link_names=model.link_names(),
        additive=False,
    )
def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
    """
    Check if the current references are valid for the given model.

    The references are valid when the shapes of the stored joint force
    references and link forces match the model. Without a model, there is
    nothing to check against and the references are considered valid.

    Args:
        model: The model to check against.

    Returns:
        `True` if the current references are valid for the given model,
        `False` otherwise.
    """

    if model is None:
        return True

    # One joint force reference for each degree of freedom, and one 6D force
    # for each link. The second check runs only if the first one passes.
    return self._joint_force_references.shape == (
        model.dofs(),
    ) and self._link_forces.shape == (model.number_of_links(), 6)
# ==================
# Extract quantities
# ==================
@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["link_names"])
def link_forces(
    self,
    model: js.model.JaxSimModel | None = None,
    data: js.data.JaxSimModelData | None = None,
    link_names: tuple[str, ...] | None = None,
) -> jtp.Matrix:
    """
    Return the link forces expressed in the frame of the active representation.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_names: The names of the links corresponding to the forces.

    Returns:
        If no model and no link names are provided, the link forces as a
        `(n_links,6)` matrix corresponding to the default link serialization
        of the original model used to build the actuation object.
        If a model is provided and no link names are provided, the link forces
        as a `(n_links,6)` matrix corresponding to the serialization of the
        provided model.
        If both a model and link names are provided, the link forces as a
        `(len(link_names),6)` matrix corresponding to the serialization of
        the passed link names vector.

    Raises:
        ValueError:
            If the model is missing and either the active representation is not
            inertial-fixed or link names are provided, if the data is missing
            and the active representation is not inertial-fixed, or if the
            data is not valid for the model.

    Note:
        The returned link forces are those passed as user inputs when integrating
        the dynamics of the model. They are summed with other forces related
        e.g. to the contact model and other kinematic constraints.
    """

    # The forces are stored in inertial-fixed representation.
    stored_forces = self._link_forces
    is_inertial = self.velocity_representation is VelRepr.Inertial

    # Without a model, only the inertial-fixed forces with the implicit
    # serialization can be returned.
    if model is None:

        if not is_inertial:
            msg = "Missing model to use a representation different from {}"
            raise ValueError(msg.format(VelRepr.Inertial.name))

        if link_names is not None:
            raise ValueError("Link names cannot be provided without a model")

        return stored_forces

    # Select the requested links, defaulting to all the links of the model.
    if link_names is None:
        link_idxs = jnp.arange(model.number_of_links())
    else:
        link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)

    # No conversion is needed in inertial-fixed representation.
    if is_inertial:
        return stored_forces[link_idxs, :]

    if data is None:
        msg = "Missing model data to use a representation different from {}"
        raise ValueError(msg.format(VelRepr.Inertial.name))

    if not_tracing(self._link_forces) and not data.valid(model=model):
        raise ValueError("The provided data is not valid for the model")

    # Convert each 6D force to the active representation using the transform
    # of the corresponding link (i.e. obtaining either L_f_L or LW_f_L).
    to_active_representation = jax.vmap(
        lambda W_f, W_H: JaxSimModelReferences.inertial_to_other_representation(
            array=W_f,
            other_representation=self.velocity_representation,
            transform=W_H,
            is_force=True,
        )
    )

    link_transforms = data._link_transforms

    return to_active_representation(
        stored_forces[link_idxs, :], link_transforms[link_idxs, :, :]
    )
def joint_force_references(
    self,
    model: js.model.JaxSimModel | None = None,
    joint_names: tuple[str, ...] | None = None,
) -> jtp.Vector:
    """
    Return the joint force references.

    Args:
        model: The model to consider.
        joint_names: The names of the joints corresponding to the forces.

    Returns:
        If no model and no joint names are provided, the joint forces as a
        `(DoFs,)` vector corresponding to the default joint serialization
        of the original model used to build the actuation object.
        If a model is provided and no joint names are provided, the joint forces
        as a `(DoFs,)` vector corresponding to the serialization of the
        provided model.
        If both a model and joint names are provided, the joint forces as a
        `(len(joint_names),)` vector corresponding to the serialization of
        the passed joint names vector.

    Raises:
        ValueError:
            If joint names are provided without a model, or if the references
            are not compatible with the provided model.

    Note:
        The returned joint forces are those passed as user inputs when integrating
        the dynamics of the model. They are summed with other joint forces related
        e.g. to the enforcement of other kinematic constraints. Keep also in mind
        that the presence of joint friction and other similar effects can make the
        actual joint forces different from the references.
    """

    stored_references = self._joint_force_references

    # Without a model, only the implicit serialization is available.
    if model is None:

        if joint_names is not None:
            raise ValueError("Joint names cannot be provided without a model")

        return stored_references

    # The compatibility check is skipped while tracing.
    incompatible = not_tracing(stored_references) and not self.valid(model=model)

    if incompatible:
        msg = "The actuation object is not compatible with the provided model"
        raise ValueError(msg)

    # Select the requested joints, defaulting to all the joints of the model.
    if joint_names is None:
        joint_idxs = jnp.arange(model.number_of_joints())
    else:
        joint_idxs = js.joint.names_to_idxs(joint_names=joint_names, model=model)

    selected_references = stored_references[joint_idxs].squeeze()

    return jnp.atleast_1d(selected_references).astype(float)
# ================
# Store quantities
# ================
@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["joint_names"])
def set_joint_force_references(
    self,
    forces: jtp.VectorLike,
    model: js.model.JaxSimModel | None = None,
    joint_names: tuple[str, ...] | None = None,
) -> Self:
    """
    Set the joint force references.

    Args:
        forces: The joint force references.
        model:
            The model to consider, only needed if a joint serialization different
            from the implicit one is used.
        joint_names: The names of the joints corresponding to the forces.

    Returns:
        A new `JaxSimModelReferences` object with the given joint force references.

    Raises:
        ValueError: If the references are not compatible with the provided model.
    """

    new_references = jnp.atleast_1d(jnp.array(forces, dtype=float).squeeze())

    if model is not None:

        # The compatibility check is skipped while tracing.
        if not_tracing(new_references) and not self.valid(model=model):
            msg = "The references object is not compatible with the provided model"
            raise ValueError(msg)

        # Select the joints to update, defaulting to all the joints of the model.
        if joint_names is None:
            joint_idxs = jnp.arange(model.number_of_joints())
        else:
            joint_idxs = js.joint.names_to_idxs(joint_names=joint_names, model=model)

        # Overwrite only the selected entries of the stored references.
        new_references = self._joint_force_references.at[joint_idxs].set(
            new_references
        )

    # Without a model, the given forces replace the stored references as a whole.
    return self.replace(
        validate=True,
        _joint_force_references=jnp.atleast_1d(new_references.squeeze()).astype(
            float
        ),
    )
@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["link_names", "additive"])
def apply_link_forces(
    self,
    forces: jtp.MatrixLike,
    model: js.model.JaxSimModel | None = None,
    data: js.data.JaxSimModelData | None = None,
    link_names: tuple[str, ...] | str | None = None,
    additive: bool = False,
) -> Self:
    """
    Apply the link forces.

    Args:
        forces: The link 6D forces in the active representation.
        model:
            The model to consider, only needed if a link serialization different
            from the implicit one is used.
        data:
            The data of the considered model, only needed if the velocity
            representation is not inertial-fixed.
        link_names:
            The names of the links corresponding to the forces. A single
            name can be passed as a plain string.
        additive:
            Whether to add the forces to the existing ones instead of replacing them.

    Returns:
        A new `JaxSimModelReferences` object with the given link forces.

    Raises:
        ValueError:
            If the model is missing and either the active representation is not
            inertial-fixed or link names are provided, if the number of link
            names does not match the number of forces, if the data is missing
            and the active representation is not inertial-fixed, or if the
            data is not valid for the model.

    Note:
        The link forces must be expressed in the active representation.
        Then, we always convert and store forces in inertial-fixed representation.
    """

    f_L = jnp.atleast_2d(forces).astype(float)

    # A single link name passed as a string must be treated as one name.
    # Otherwise, its length would count the characters and the name would be
    # iterated character by character when resolving the link indices.
    if isinstance(link_names, str):
        link_names = (link_names,)

    # Helper function to replace the link forces.
    def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
        return self.replace(
            validate=True,
            _link_forces=jnp.atleast_2d(forces.squeeze()).astype(float),
        )

    # In this case, we allow only to set the inertial 6D forces to all links
    # using the implicit link serialization.
    if model is None:

        if self.velocity_representation is not VelRepr.Inertial:
            msg = "Missing model to use a representation different from {}"
            raise ValueError(msg.format(VelRepr.Inertial.name))

        if link_names is not None:
            raise ValueError("Link names cannot be provided without a model")

        W_f_L = f_L

        W_f0_L = jnp.zeros_like(W_f_L) if not additive else self._link_forces

        return replace(forces=W_f0_L + W_f_L)

    if link_names is not None and len(link_names) != f_L.shape[0]:
        msg = "The number of link names ({}) must match the number of forces ({})"
        raise ValueError(msg.format(len(link_names), f_L.shape[0]))

    # Extract the link indices.
    link_idxs = (
        js.link.names_to_idxs(link_names=link_names, model=model)
        if link_names is not None
        else jnp.arange(model.number_of_links())
    )

    # Compute the bias depending on whether we either set or add the link forces.
    W_f0_L = (
        jnp.zeros_like(f_L) if not additive else self._link_forces[link_idxs, :]
    )

    # If inertial-fixed representation, we can directly store the link forces.
    if self.velocity_representation is VelRepr.Inertial:

        W_f_L = f_L

        return replace(
            forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L)
        )

    if data is None:
        msg = "Missing model data to use a representation different from {}"
        raise ValueError(msg.format(VelRepr.Inertial.name))

    if not_tracing(forces) and not data.valid(model=model):
        raise ValueError("The provided data is not valid for the model")

    W_H_L = data._link_transforms

    # Convert a single 6D force to the inertial representation
    # considering as body the link (i.e. L_f_L and LW_f_L).
    # The f_L input is either L_f_L or LW_f_L, depending on the representation.
    W_f_L = JaxSimModelReferences.other_representation_to_inertial(
        array=f_L,
        other_representation=self.velocity_representation,
        transform=W_H_L[link_idxs] if model.number_of_links() > 1 else W_H_L,
        is_force=True,
    )

    return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L))
def apply_frame_forces(
self,
forces: jtp.MatrixLike,
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
frame_names: tuple[str, ...] | str | None = None,
additive: bool = False,
) -> Self:
"""
Apply the frame forces.
Args:
forces: The frame 6D forces in the active representation.
model:
The model to consider, only needed if a frame serialization different
from the implicit one is used.
data:
The data of the considered model, only needed if the velocity
representation is not inertial-fixed.
frame_names: The names of the frames corresponding to the forces.
additive:
Whether to add the forces to the existing ones instead of replacing them.
Returns:
A new `JaxSimModelReferences` object with the given frame forces.
Note:
The frame forces must be expressed in the active representation.
Then, we always convert and store forces in inertial-fixed representation.
"""
f_F = jnp.atleast_2d(forces).astype(float)
if len(frame_names) != f_F.shape[0]:
msg = "The number of frame names ({}) must match the number of forces ({})"
raise ValueError(msg.format(len(frame_names), f_F.shape[0]))
# Extract the frame indices.
frame_idxs = (
js.frame.names_to_idxs(frame_names=frame_names, model=model)
if frame_names is not None
else jnp.arange(len(model.frame_names()))
)
parent_link_idxs = jnp.array(model.kin_dyn_parameters.frame_parameters.body)[
frame_idxs - model.number_of_links()
]
exceptions.raise_value_error_if(
condition=~data.valid(model=model),
msg="The provided data is not valid for the model",
)
W_H_Fi = jax.vmap(
lambda frame_idx: js.frame.transform(
model=model, data=data, frame_index=frame_idx
)
)(frame_idxs)
# Helper function to convert a single 6D force to the inertial representation
# considering as body the frame (i.e. L_f_F and LW_f_F).
def to_inertial(f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike) -> jtp.Matrix:
return JaxSimModelReferences.other_representation_to_inertial(
array=f_F,
other_representation=self.velocity_representation,
transform=W_H_F,
is_force=True,
)
match self.velocity_representation:
case VelRepr.Inertial:
W_f_F = f_F
case VelRepr.Body | VelRepr.Mixed:
W_f_F = jax.vmap(to_inertial)(f_F, W_H_Fi)
case _:
raise ValueError("Invalid velocity representation.")
# Sum the forces on the parent links.
mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links())
W_f_L = mask.T @ W_f_F
with self.switch_velocity_representation(
velocity_representation=VelRepr.Inertial
):
references = self.apply_link_forces(
model=model,
data=data,
forces=W_f_L,
additive=additive,
)
with references.switch_velocity_representation(
velocity_representation=self.velocity_representation
):
return references
================================================
FILE: src/jaxsim/exceptions.py
================================================
import os
import jax
def raise_if(
condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs
) -> None:
"""
Raise a host-side exception if a condition is met. Useful in jit-compiled functions.
Args:
condition:
The boolean condition of the evaluated expression that triggers
the exception during runtime.
exception: The type of exception to raise.
msg:
The message to display when the exception is raised. The message can be a
format string (fmt), whose fields are filled with the args and kwargs.
*args: The arguments to fill the format string.
**kwargs: The keyword arguments to fill the format string
"""
# Disable host callback if running on unsupported hardware or if the user
# explicitly disabled it.
if jax.devices()[0].platform in {"tpu", "METAL"} or not os.environ.get(
"JAXSIM_ENABLE_EXCEPTIONS", 0
):
return
# Check early that the format string is well-formed.
try:
_ = msg.format(*args, **kwargs)
except Exception as e:
msg = "Error in formatting exception message with args={} and kwargs={}"
raise ValueError(msg.format(args, kwargs)) from e
def _raise_exception(condition: bool, *args, **kwargs) -> None:
"""The function called by the JAX callback."""
if condition:
raise exception(msg.format(*args, **kwargs))
def _callback(args, kwargs) -> None:
"""The function that calls the JAX callback, executed only when needed."""
jax.debug.callback(_raise_exception, condition, *args, **kwargs)
# Since running a callable on the host is expensive, we prevent its execution
# if the condition is False with a low-level conditional expression.
def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None:
return jax.lax.cond(
condition,
_callback,
lambda args, kwargs: None,
args,
kwargs,
)
return _run_callback_only_if_condition_is_true(*args, **kwargs)
def raise_runtime_error_if(
    condition: bool | jax.Array, msg: str, *args, **kwargs
) -> None:
    """
    Raise a RuntimeError if a condition is met. Useful in jit-compiled functions.

    All the arguments are forwarded to `raise_if`, which receives
    ``RuntimeError`` as the exception type.

    Args:
        condition: The condition that triggers the exception.
        msg: The (format) message of the exception.
        *args: The positional arguments filling the format message.
        **kwargs: The keyword arguments filling the format message.
    """

    error_type = RuntimeError

    return raise_if(condition, error_type, msg, *args, **kwargs)
def raise_value_error_if(
    condition: bool | jax.Array, msg: str, *args, **kwargs
) -> None:
    """
    Raise a ValueError if a condition is met. Useful in jit-compiled functions.

    All the arguments are forwarded to `raise_if`, which receives
    ``ValueError`` as the exception type.

    Args:
        condition: The condition that triggers the exception.
        msg: The (format) message of the exception.
        *args: The positional arguments filling the format message.
        **kwargs: The keyword arguments filling the format message.
    """

    error_type = ValueError

    return raise_if(condition, error_type, msg, *args, **kwargs)
================================================
FILE: src/jaxsim/logging.py
================================================
import enum
import inspect
import logging
import os
import warnings
import coloredlogs
class JaxSimWarning(UserWarning):
    """Warning category of the warnings issued through `jaxsim_warn`."""
# Keep a reference to the hook that displays warnings at the time this module
# is imported, so that it can be used as fallback by the custom hook.
_original_showwarning = warnings.showwarning

# Absolute path of the parent directory of the folder containing this file.
_jaxsim_root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
def pretty_jaxsim_warning(message, category, filename, lineno, file=None, line=None):
    """
    Display a warning, using a compact colored format for selected callers.

    The arguments are those of the `warnings.showwarning` hook. If the file of
    the inspected caller frame is located under `_jaxsim_root_dir`, the warning
    is printed with a custom format. Otherwise, it is forwarded unchanged to
    the hook stored in `_original_showwarning`.
    """

    # The stack has to be inspected from this very frame, since the index of
    # the selected frame is relative to it. If the frame cannot be retrieved,
    # fall back to the file reported by the warning itself.
    try:
        origin = inspect.stack()[2].filename
    except Exception:
        origin = filename

    if not origin.startswith(_jaxsim_root_dir):
        _original_showwarning(message, category, filename, lineno, file, line)
        return

    print(f"\033[93m⚠️ {category.__name__}:\033[0m {message}")
    print(f" → {filename}:{lineno}")
# Register the custom formatter and configure JaxSimWarning to be shown only once.
# The filter is restricted to the JaxSimWarning category, so that importing this
# module does not change how the warnings of the other categories are filtered.
warnings.showwarning = pretty_jaxsim_warning
warnings.filterwarnings("once", category=JaxSimWarning)
def jaxsim_warn(msg):
    """
    Issue a `JaxSimWarning` with the given message.

    The warning is attributed to the caller of this function (``stacklevel=2``).

    Args:
        msg: The message of the warning.
    """

    warnings.warn(msg, JaxSimWarning, stacklevel=2)
# Name of the logger retrieved by `_logger` and used by the helpers of this module.
LOGGER_NAME = "jaxsim"
class LoggingLevel(enum.IntEnum):
    """
    Enumeration of the supported logging levels.

    Each member has the same integer value of the level with the same name
    defined by the standard `logging` module.
    """

    NOTSET = logging.NOTSET
    DEBUG = logging.DEBUG
    INFO = logging.INFO
    WARNING = logging.WARNING
    ERROR = logging.ERROR
    CRITICAL = logging.CRITICAL
def _logger() -> logging.Logger:
    """Return the logger whose name is `LOGGER_NAME`."""

    jaxsim_logger = logging.getLogger(name=LOGGER_NAME)
    return jaxsim_logger
def set_logging_level(level: int | LoggingLevel = LoggingLevel.WARNING):
    """
    Set the level of the logger returned by `_logger`.

    Args:
        level: The new level, either as integer or as `LoggingLevel`.
    """

    # Integers are converted to the enum, whose constructor rejects the values
    # that do not correspond to any member.
    selected = LoggingLevel(level) if isinstance(level, int) else level

    _logger().setLevel(level=selected.value)
def get_logging_level() -> LoggingLevel:
    """
    Get the effective level of the logger returned by `_logger`.

    Returns:
        The effective level converted to `LoggingLevel`.
    """

    return LoggingLevel(_logger().getEffectiveLevel())
def configure(level: LoggingLevel = LoggingLevel.WARNING) -> None:
    """
    Configure the logger returned by `_logger`.

    A stream handler with a colored formatter is attached to the logger, the
    propagation to the parent loggers is disabled, and the level is set.
    Calling this function more than once does not attach additional handlers,
    so that the messages are not emitted multiple times.

    Args:
        level: The logging level to set.
    """

    info("Configuring the 'jaxsim' logger")

    logger = _logger()

    # Detect the handler attached by a previous call of this function.
    already_attached = any(
        isinstance(existing, logging.StreamHandler)
        and isinstance(existing.formatter, coloredlogs.ColoredFormatter)
        for existing in logger.handlers
    )

    if not already_attached:
        handler = logging.StreamHandler()
        fmt = "%(name)s[%(process)d] %(levelname)s %(message)s"
        handler.setFormatter(fmt=coloredlogs.ColoredFormatter(fmt=fmt))
        logger.addHandler(hdlr=handler)

    # Do not propagate the messages to handlers of parent loggers
    # (preventing duplicate logging)
    logger.propagate = False

    set_logging_level(level=level)
def debug(msg: str = "") -> None:
    """Log a message with the DEBUG level on the logger returned by `_logger`."""

    logger = _logger()
    logger.debug(msg=msg)
def info(msg: str = "") -> None:
    """Log a message with the INFO level on the logger returned by `_logger`."""

    logger = _logger()
    logger.info(msg=msg)
def warning(msg: str = "") -> None:
    """Log a message with the WARNING level on the logger returned by `_logger`."""

    logger = _logger()
    logger.warning(msg=msg)
def error(msg: str = "") -> None:
    """Log a message with the ERROR level on the logger returned by `_logger`."""

    logger = _logger()
    logger.error(msg=msg)
def critical(msg: str = "") -> None:
    """Log a message with the CRITICAL level on the logger returned by `_logger`."""

    logger = _logger()
    logger.critical(msg=msg)
def exception(msg: str = "") -> None:
    """
    Log a message through `logging.Logger.exception` on the logger
    returned by `_logger`.
    """

    logger = _logger()
    logger.exception(msg=msg)
================================================
FILE: src/jaxsim/math/__init__.py
================================================
from .adjoint import Adjoint
from .cross import Cross
from .inertia import Inertia
from .quaternion import Quaternion
from .rotation import Rotation
from .skew import Skew
from .transform import Transform
from .utils import safe_norm
from .joint_model import JointModel, supported_joint_motion # isort:skip
# Define the default standard gravity constant.
# NOTE(review): presumably expressed in m/s^2 -- confirm against the callers.
STANDARD_GRAVITY = 9.81
================================================
FILE: src/jaxsim/math/adjoint.py
================================================
import jax.numpy as jnp
import jaxlie
import jaxsim.typing as jtp
from .skew import Skew
class Adjoint:
    """
    Collection of static helpers to build and manipulate 6x6 adjoint matrices.

    The matrices built by this class have the block structure
    ``[[R, S(p) @ R], [0, R]]``, where ``R`` is a rotation matrix, ``p`` a
    translation, and ``S(.)`` the skew-symmetric operator.
    """

    @staticmethod
    def from_quaternion_and_translation(
        quaternion: jtp.Vector | None = None,
        translation: jtp.Vector | None = None,
        inverse: bool = False,
        normalize_quaternion: bool = False,
    ) -> jtp.Matrix:
        """
        Build an adjoint matrix from an orientation quaternion and a translation.

        Args:
            quaternion:
                The 4D quaternion in WXYZ representation.
                It defaults to the identity quaternion.
            translation: The 3D translation. It defaults to zero.
            inverse: Whether to build the inverse adjoint.
            normalize_quaternion:
                Whether to normalize the quaternion before using it.

        Returns:
            The adjoint matrix.
        """

        if quaternion is None:
            quaternion = jnp.array([1.0, 0, 0, 0])

        if translation is None:
            translation = jnp.zeros(3)

        assert quaternion.size == 4
        assert translation.size == 3

        orientation = jaxlie.SO3(wxyz=quaternion)

        if normalize_quaternion:
            orientation = orientation.normalize()

        return Adjoint.from_rotation_and_translation(
            rotation=orientation.as_matrix(), translation=translation, inverse=inverse
        )

    @staticmethod
    def from_transform(transform: jtp.MatrixLike, inverse: bool = False) -> jtp.Matrix:
        """
        Build an adjoint matrix from a homogeneous transformation.

        Args:
            transform: The 4x4 transformation matrix.
            inverse: Whether to build the adjoint of the inverse transformation.

        Returns:
            The 6x6 adjoint matrix.
        """

        H = jaxlie.SE3.from_matrix(matrix=transform)

        if inverse:
            H = H.inverse()

        return H.adjoint()

    @staticmethod
    def from_rotation_and_translation(
        rotation: jtp.Matrix | None = None,
        translation: jtp.Vector | None = None,
        inverse: bool = False,
    ) -> jtp.Matrix:
        """
        Build an adjoint matrix from a rotation matrix and a translation.

        Args:
            rotation: The 3x3 rotation matrix. It defaults to the identity.
            translation: The 3D translation. It defaults to zero.
            inverse: Whether to build the inverse adjoint. Default is False.

        Returns:
            The adjoint matrix.
        """

        if rotation is None:
            rotation = jnp.eye(3)

        if translation is None:
            translation = jnp.zeros(3)

        assert rotation.shape == (3, 3)
        assert translation.size == 3

        R = rotation.squeeze()
        p = translation.squeeze()

        if inverse:
            # Rows of the inverse adjoint: [R.T, -R.T S(p)] and [0, R.T].
            upper = jnp.block([R.T, -R.T @ Skew.wedge(p)])
            lower = jnp.block([jnp.zeros(shape=(3, 3)), R.T])
        else:
            # Rows of the adjoint: [R, S(p) R] and [0, R].
            upper = jnp.block([R, Skew.wedge(p) @ R])
            lower = jnp.block([jnp.zeros(shape=(3, 3)), R])

        return jnp.vstack([upper, lower])

    @staticmethod
    def to_transform(adjoint: jtp.Matrix) -> jtp.Matrix:
        """
        Extract the homogeneous transformation from an adjoint matrix.

        Args:
            adjoint: The adjoint matrix (6x6).

        Returns:
            The transformation matrix (4x4).
        """

        X = adjoint.squeeze()
        assert X.shape == (6, 6)

        rotation = X[0:3, 0:3]
        skew_times_rotation = X[0:3, 3:6]

        # Removing the rotation from the top-right block leaves the
        # skew-symmetric matrix of the translation.
        upper = jnp.block([rotation, Skew.vee(matrix=skew_times_rotation @ rotation.T)])
        lower = jnp.array([0, 0, 0, 1])

        return jnp.vstack([upper, lower])

    @staticmethod
    def inverse(adjoint: jtp.Matrix) -> jtp.Matrix:
        """
        Compute the inverse of an adjoint matrix.

        Args:
            adjoint:
                The adjoint matrix. The input is reshaped to a stack of 6x6
                matrices, and the output has the shape of the input.

        Returns:
            The inverse adjoint matrix.
        """

        X = adjoint.reshape(-1, 6, 6)

        R_T = jnp.swapaxes(X[..., 0:3, 0:3], -2, -1)
        top_right = X[..., 0:3, 3:6]

        upper = jnp.concatenate([R_T, -R_T @ top_right @ R_T], axis=-1)
        lower = jnp.concatenate([jnp.zeros_like(R_T), R_T], axis=-1)

        return jnp.concatenate([upper, lower], axis=-2).reshape(adjoint.shape)
================================================
FILE: src/jaxsim/math/cross.py
================================================
import jax.numpy as jnp
import jaxsim.typing as jtp
from .skew import Skew
class Cross:
    """
    A utility class for cross product matrix operations.
    """

    @staticmethod
    def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix:
        """
        Compute the cross product matrix for 6D velocities.

        The resulting matrix has the block structure
        ``[[S(ω), S(v)], [0, S(ω)]]``, where ``S(.)`` is the skew-symmetric
        operator.

        Args:
            velocity_sixd:
                A 6D velocity vector [v, ω]. The input is reshaped to rows of
                6 elements, therefore a stack of velocities is also accepted.

        Returns:
            The cross product matrix (6x6), or a stack of them.

        Note:
            No explicit validation is performed. An input whose size is not a
            multiple of 6 makes the reshape operation fail.
        """

        velocity_sixd = velocity_sixd.reshape(-1, 6)

        v, ω = jnp.split(velocity_sixd, 2, axis=-1)

        v_cross = jnp.concatenate(
            [
                jnp.concatenate(
                    [Skew.wedge(ω), jnp.zeros((ω.shape[0], 3, 3)).squeeze()], axis=-2
                ),
                jnp.concatenate([Skew.wedge(v), Skew.wedge(ω)], axis=-2),
            ],
            axis=-1,
        )

        return v_cross

    @staticmethod
    def vx_star(velocity_sixd: jtp.Vector) -> jtp.Matrix:
        """
        Compute the negative transpose of the cross product matrix for 6D velocities.

        Args:
            velocity_sixd:
                A 6D velocity vector [v, ω], or a stack of them as accepted
                by `Cross.vx`.

        Returns:
            The negative transpose of the cross product matrix (6x6),
            or a stack of them.
        """

        # Transpose only the last two axes: a plain `.T` would reverse all the
        # axes, mixing the batch axis with the matrix axes for stacked inputs.
        v_cross_star = -jnp.swapaxes(Cross.vx(velocity_sixd), -2, -1)

        return v_cross_star
================================================
FILE: src/jaxsim/math/inertia.py
================================================
import jax.numpy as jnp
import jaxsim.typing as jtp
from .skew import Skew
class Inertia:
    """
    A utility class for inertia matrix operations.
    """

    @staticmethod
    def to_sixd(mass: jtp.Float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Matrix:
        """
        Convert mass, center of mass, and inertia matrix to a 6x6 inertia matrix.

        With ``c = S(com)`` the skew-symmetric matrix of the center of mass,
        the result has the block structure
        ``[[mass * eye(3), mass * c.T], [mass * c, I + mass * c @ c.T]]``.

        Args:
            mass: The mass of the body.
            com: The center of mass position (3D).
            I: The 3x3 inertia matrix.

        Returns:
            The 6x6 inertia matrix.

        Raises:
            ValueError: If the shape of the inertia matrix I is not (3, 3).
        """

        if I.shape != (3, 3):
            raise ValueError(I, I.shape)

        c = Skew.wedge(vector=com)

        M = jnp.vstack(
            [
                jnp.block([mass * jnp.eye(3), mass * c.T]),
                jnp.block([mass * c, I + mass * c @ c.T]),
            ]
        )

        return M

    @staticmethod
    def to_params(M: jtp.Matrix) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]:
        """
        Convert a 6x6 inertia matrix to mass, center of mass, and inertia matrix.

        This is the inverse operation of `Inertia.to_sixd`.

        Args:
            M: The 6x6 inertia matrix.

        Returns:
            A tuple containing mass, center of mass (3D), and inertia matrix (3x3).

        Raises:
            ValueError: If the input matrix M has an unexpected shape.
        """

        # Validate the shape as documented, mirroring the check of `to_sixd`.
        if M.shape != (6, 6):
            raise ValueError(M, M.shape)

        # The top-left block is the mass times the identity.
        m = jnp.diag(M[0:3, 0:3]).sum() / 3

        # The bottom-left block is the mass times the skew-symmetric matrix
        # of the center of mass.
        mC = M[3:6, 0:3]
        c = Skew.vee(mC) / m

        # Remove from the bottom-right block the term depending on the center of mass.
        I = M[3:6, 3:6] - (mC @ mC.T / m)

        return m, c, I
================================================
FILE: src/jaxsim/math/joint_model.py
================================================
from __future__ import annotations
import jax
import jax.numpy as jnp
import jax_dataclasses
import jaxlie
from jax_dataclasses import Static
import jaxsim.typing as jtp
from jaxsim.math import Rotation
from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription
from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms
from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass
@jax_dataclasses.pytree_dataclass
class JointModel(JaxsimDataclass):
    """
    Kinematic description of the joints of a robot model.

    Attributes:
        λ_H_pre:
            The homogeneous transformation between the parent link and
            the predecessor frame of each joint.
        suc_H_i:
            The homogeneous transformation between the successor frame and
            the child link of each joint.
        joint_dofs: The number of DoFs of each joint.
        joint_names: The names of each joint.
        joint_types: The types of each joint.
        joint_axis: The axes of the joints of the model description.

    Note:
        Due to the presence of the static attributes, this class needs to be created
        already in a vectorized form. In other words, it cannot be created using vmap.
    """

    λ_H_pre: jtp.Array
    suc_H_i: jtp.Array

    joint_dofs: Static[tuple[int, ...]]
    joint_names: Static[tuple[str, ...]]
    joint_types: Static[tuple[int, ...]]
    joint_axis: Static[tuple[JointGenericAxis, ...]]

    @staticmethod
    def build(description: ModelDescription) -> JointModel:
        """
        Create the joint model corresponding to a model description.

        Args:
            description: The model description to consider.

        Returns:
            The joint model of the considered model description.
        """

        # Links sorted by index. The link index is equal to its body index.
        links = sorted(description.links_dict.values(), key=lambda link: link.index)

        # Joints sorted by index. The joint index is equal to its child link
        # index, therefore it starts from 1.
        joints = sorted(description.joints_dict.values(), key=lambda joint: joint.index)

        # The entry 0 is reserved to the joint between the world frame W and
        # the base link B.
        number_of_entries = 1 + len(joints)

        # Allocate the parent-to-predecessor and successor-to-child transforms.
        λ_H_pre = jnp.zeros(shape=(number_of_entries, 4, 4), dtype=float)
        suc_H_i = jnp.zeros(shape=(number_of_entries, 4, 4), dtype=float)

        # The world-to-base joint has an identity parent-to-predecessor transform.
        λ_H_pre = λ_H_pre.at[0].set(jnp.eye(4))

        # The successor-to-child transform of the world-to-base joint stores the
        # optional transform between the root frame of the model and the base
        # link frame (needed only if the pose of the link frame w.r.t. the
        # implicit __model__ SDF frame is not the identity).
        suc_H_i = suc_H_i.at[0].set(links[0].pose)

        # Create the object to compute forward kinematics.
        fk = KinematicGraphTransforms(graph=description)

        # Fill the transforms of the joints of the model. Their indices start
        # from 1, hence the entries at index 0 are not touched by this loop.
        for joint in joints:

            parent_H_pre = fk.relative_transform(
                relative_to=joint.parent.name, name=joint.name
            )
            λ_H_pre = λ_H_pre.at[joint.index].set(parent_H_pre)

            suc_H_child = fk.relative_transform(
                relative_to=joint.name, name=joint.child.name
            )
            suc_H_i = suc_H_i.at[joint.index].set(suc_H_child)

        # Define the DoFs of the base link.
        base_dofs = 0 if description.fixed_base else 6

        # We always add a dummy fixed joint between world and base.
        # TODO: Port floating-base support also at this level, not only in RBDAs.
        return JointModel(
            λ_H_pre=λ_H_pre,
            suc_H_i=suc_H_i,
            # Static attributes
            joint_dofs=(base_dofs, *(1 for _ in joints)),
            joint_names=("world_to_base", *(joint.name for joint in joints)),
            joint_types=(JointType.Fixed, *(joint.jtype for joint in joints)),
            joint_axis=tuple(JointGenericAxis(axis=joint.axis) for joint in joints),
        )

    def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix:
        r"""
        Get the transform between the parent link and the predecessor frame
        of a joint.

        Args:
            joint_index: The index of the joint.

        Returns:
            The homogeneous transformation
            :math:`{}^{\lambda(i)} \mathbf{H}_{\text{pre}(i)}`.
        """

        parent_H_pre = self.λ_H_pre[joint_index]
        return parent_H_pre

    def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix:
        r"""
        Get the transform between the successor frame and the child link
        of a joint.

        Args:
            joint_index: The index of the joint.

        Returns:
            The homogeneous transformation
            :math:`{}^{\text{suc}(i)} \mathbf{H}_i`.
        """

        suc_H_child = self.suc_H_i[joint_index]
        return suc_H_child
@jax.jit
def supported_joint_motion(
    joint_types: jtp.Array, joint_positions: jtp.Matrix, joint_axes: jtp.Matrix
) -> jtp.Matrix:
    """
    Compute the predecessor-to-successor transform produced by a joint.

    Args:
        joint_types: The types of the joints.
        joint_positions: The positions of the joints.
        joint_axes: The axes of the joints.

    Returns:
        The transforms of the joints.
    """

    # Prepare the joint position.
    s = jnp.array(joint_positions).astype(float)

    # Prepare the joint axis. This is a metadata required by only some joint types.
    axis = jnp.array(joint_axes).astype(float).squeeze()

    def fixed_joint() -> jaxlie.SE3:
        # A fixed joint does not produce any motion.
        return jaxlie.SE3.identity()

    def revolute_joint() -> jaxlie.SE3:
        # Pure rotation around the joint axis.
        pre_R_suc = Rotation.from_axis_angle(vector=s * axis)

        return jaxlie.SE3.from_matrix(matrix=jnp.eye(4).at[:3, :3].set(pre_R_suc))

    def prismatic_joint() -> jaxlie.SE3:
        # Pure translation along the joint axis.
        return jaxlie.SE3.from_rotation_and_translation(
            rotation=jaxlie.SO3.identity(),
            translation=jnp.array(s * axis),
        )

    pre_H_suc = jax.lax.switch(
        index=joint_types,
        branches=(
            fixed_joint,  # JointType.Fixed
            revolute_joint,  # JointType.Revolute
            prismatic_joint,  # JointType.Prismatic
        ),
    )

    return pre_H_suc.as_matrix()
================================================
FILE: src/jaxsim/math/quaternion.py
================================================
import jax.lax
import jax.numpy as jnp
import jaxlie
import jaxsim.typing as jtp
from .utils import safe_norm
class Quaternion:
    """
    Collection of static helpers operating on quaternions.
    """

    @staticmethod
    def to_xyzw(wxyz: jtp.Vector) -> jtp.Vector:
        """
        Reorder the elements of a quaternion from WXYZ to XYZW.

        Args:
            wxyz: Quaternion in WXYZ representation.

        Returns:
            Quaternion in XYZW representation.
        """

        permutation = jnp.array([1, 2, 3, 0])

        return wxyz.squeeze()[permutation]

    @staticmethod
    def to_wxyz(xyzw: jtp.Vector) -> jtp.Vector:
        """
        Reorder the elements of a quaternion from XYZW to WXYZ.

        Args:
            xyzw: Quaternion in XYZW representation.

        Returns:
            Quaternion in WXYZ representation.
        """

        permutation = jnp.array([3, 0, 1, 2])

        return xyzw.squeeze()[permutation]

    @staticmethod
    def to_dcm(quaternion: jtp.Vector) -> jtp.Matrix:
        """
        Convert a quaternion to a direction cosine matrix (DCM).

        Args:
            quaternion: Quaternion in WXYZ representation.

        Returns:
            The Direction cosine matrix (DCM).
        """

        rotation = jaxlie.SO3(wxyz=quaternion)

        return rotation.as_matrix()

    @staticmethod
    def from_dcm(dcm: jtp.Matrix) -> jtp.Vector:
        """
        Convert a direction cosine matrix (DCM) to a quaternion.

        Args:
            dcm: Direction cosine matrix (DCM).

        Returns:
            Quaternion in WXYZ representation.
        """

        rotation = jaxlie.SO3.from_matrix(matrix=dcm)

        return rotation.wxyz

    @staticmethod
    def derivative(
        quaternion: jtp.Vector,
        omega: jtp.Vector,
        omega_in_body_fixed: bool = False,
        K: float = 0.1,
    ) -> jtp.Vector:
        """
        Compute the derivative of a quaternion given angular velocity.

        Args:
            quaternion: Quaternion in WXYZ representation.
            omega: Angular velocity vector.
            omega_in_body_fixed:
                Whether the angular velocity is in the body-fixed frame.
            K:
                The gain of the additional term that is proportional to the
                norm of the angular velocity and to the deviation of the
                quaternion norm from 1.

        Returns:
            The derivative of the quaternion, stacked as a column.
        """

        ω = omega.squeeze()
        q = quaternion.squeeze()

        def matrix_for_body_fixed(q: jtp.Vector) -> jtp.Matrix:
            qw, qx, qy, qz = q

            return jnp.array(
                [
                    [qw, -qx, -qy, -qz],
                    [qx, qw, -qz, qy],
                    [qy, qz, qw, -qx],
                    [qz, -qy, qx, qw],
                ]
            )

        def matrix_for_inertial_fixed(q: jtp.Vector) -> jtp.Matrix:
            qw, qx, qy, qz = q

            return jnp.array(
                [
                    [qw, -qx, -qy, -qz],
                    [qx, qw, qz, -qy],
                    [qy, -qz, qw, qx],
                    [qz, qy, -qx, qw],
                ]
            )

        # Select the matrix matching the representation of the angular velocity.
        Q = jax.lax.cond(
            pred=omega_in_body_fixed,
            true_fun=matrix_for_body_fixed,
            false_fun=matrix_for_inertial_fixed,
            operand=q,
        )

        # The first element vanishes when the quaternion has unit norm.
        norm_correction = K * safe_norm(ω) * (1 - safe_norm(q))

        qd = 0.5 * (Q @ jnp.hstack([norm_correction, ω]))

        return jnp.vstack(qd)

    @staticmethod
    def integration(
        quaternion: jtp.VectorLike,
        dt: jtp.FloatLike,
        omega: jtp.VectorLike,
        omega_in_body_fixed: jtp.BoolLike = False,
    ) -> jtp.Vector:
        """
        Integrate a quaternion in SO(3) given an angular velocity.

        Args:
            quaternion: The quaternion to integrate.
            dt: The time step.
            omega: The angular velocity vector.
            omega_in_body_fixed:
                Whether the angular velocity is in body-fixed representation
                as opposed to the default inertial-fixed representation.

        Returns:
            The integrated quaternion.
        """

        ω = jnp.array(omega).squeeze().astype(float)
        q_t0 = jnp.array(quaternion).squeeze().astype(float)

        # Build the initial SO(3) element and the increment over the time step.
        R_t0 = jaxlie.SO3(wxyz=q_t0)
        ΔR = jaxlie.SO3.exp(tangent=dt * ω)

        # Integrate the quaternion on the manifold. The increment is applied on
        # the right for body-fixed velocities, and on the left otherwise.
        return jax.lax.select(
            pred=omega_in_body_fixed,
            on_true=(R_t0 @ ΔR).wxyz,
            on_false=(ΔR @ R_t0).wxyz,
        )
================================================
FILE: src/jaxsim/math/rotation.py
================================================
import jax.numpy as jnp
import jaxlie
import jaxsim.typing as jtp
from .skew import Skew
from .utils import safe_norm
class Rotation:
    """
    Collection of static helpers operating on rotation matrices.
    """

    @staticmethod
    def x(theta: jtp.Float) -> jtp.Matrix:
        """
        Build the rotation matrix of a rotation around the X-axis.

        Args:
            theta: Rotation angle in radians.

        Returns:
            The 3D rotation matrix.
        """

        rotation = jaxlie.SO3.from_x_radians(theta=theta)
        return rotation.as_matrix()

    @staticmethod
    def y(theta: jtp.Float) -> jtp.Matrix:
        """
        Build the rotation matrix of a rotation around the Y-axis.

        Args:
            theta: Rotation angle in radians.

        Returns:
            The 3D rotation matrix.
        """

        rotation = jaxlie.SO3.from_y_radians(theta=theta)
        return rotation.as_matrix()

    @staticmethod
    def z(theta: jtp.Float) -> jtp.Matrix:
        """
        Build the rotation matrix of a rotation around the Z-axis.

        Args:
            theta: Rotation angle in radians.

        Returns:
            The 3D rotation matrix.
        """

        rotation = jaxlie.SO3.from_z_radians(theta=theta)
        return rotation.as_matrix()

    @staticmethod
    def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:
        """
        Build a rotation matrix from an axis-angle representation.

        Args:
            vector:
                The rotation as a 3D vector, whose norm is used as the
                rotation angle and whose direction as the rotation axis.

        Returns:
            The SO(3) rotation matrix.
        """

        rotation_vector = vector.squeeze()
        angle = safe_norm(rotation_vector)

        sin_angle = jnp.sin(angle)
        cos_angle = jnp.cos(angle)
        two_sin_half_squared = 2 * jnp.sin(angle / 2.0) ** 2

        # Avoid dividing by zero when the rotation vector is null.
        divisor = jnp.where(angle == 0, 1.0, angle)

        # Rotation axis as a column vector.
        axis = jnp.vstack((rotation_vector / divisor).squeeze())

        transposed = (
            cos_angle * jnp.eye(3)
            - sin_angle * Skew.wedge(axis)
            + (two_sin_half_squared * axis) @ axis.T
        )

        return transposed.transpose()

    @staticmethod
    def log_vee(R: jnp.ndarray) -> jtp.Vector:
        """
        Compute the logarithm map of an SO(3) rotation matrix.

        Args:
            R: The SO(3) rotation matrix.

        Returns:
            The corresponding so(3) tangent vector.
        """

        rotation = jaxlie.SO3.from_matrix(R)
        return rotation.log()
================================================
FILE: src/jaxsim/math/skew.py
================================================
import jax.numpy as jnp
import jaxsim.typing as jtp
class Skew:
    """
    Collection of static helpers operating on skew-symmetric matrices.
    """

    @staticmethod
    def wedge(vector: jtp.Vector) -> jtp.Matrix:
        """
        Build the skew-symmetric matrix (wedge operator) of a 3D vector.

        Args:
            vector:
                A 3D vector. The input is reshaped to rows of 3 elements,
                therefore a stack of vectors is also accepted.

        Returns:
            The skew-symmetric matrix corresponding to the input vector,
            with the dimensions of size 1 removed.
        """

        rows = vector.reshape(-1, 3)

        x, y, z = jnp.split(rows, 3, axis=-1)
        zero = jnp.zeros_like(x)

        first_row = jnp.concatenate([zero, -z, y], axis=-1)
        second_row = jnp.concatenate([z, zero, -x], axis=-1)
        third_row = jnp.concatenate([-y, x, zero], axis=-1)

        return jnp.stack([first_row, second_row, third_row], axis=-2).squeeze()

    @staticmethod
    def vee(matrix: jtp.Matrix) -> jtp.Vector:
        """
        Extract the 3D vector from a skew-symmetric matrix (vee operator).

        Args:
            matrix: A 3x3 skew-symmetric matrix.

        Returns:
            The 3D vector extracted from the input matrix, stacked as a column.
        """

        # Each element is the average of the two off-diagonal entries storing it.
        x = matrix[2, 1] - matrix[1, 2]
        y = matrix[0, 2] - matrix[2, 0]
        z = matrix[1, 0] - matrix[0, 1]

        return 0.5 * jnp.vstack([x, y, z])
================================================
FILE: src/jaxsim/math/transform.py
================================================
import jax.numpy as jnp
import jaxlie
import jaxsim.typing as jtp
class Transform:
    """
    Collection of static helpers operating on homogeneous transformations.
    """

    @staticmethod
    def from_quaternion_and_translation(
        quaternion: jtp.VectorLike | None = None,
        translation: jtp.VectorLike | None = None,
        inverse: jtp.BoolLike = False,
        normalize_quaternion: jtp.BoolLike = False,
    ) -> jtp.Matrix:
        """
        Build a transformation matrix from a quaternion and a translation.

        Args:
            quaternion:
                A 4D vector representing a SO(3) orientation in WXYZ
                representation. It defaults to the identity quaternion.
            translation:
                A 3D vector representing a translation. It defaults to zero.
            inverse: Whether to compute the inverse transformation.
            normalize_quaternion:
                Whether to normalize the quaternion before creating the transformation.

        Returns:
            The 4x4 transformation matrix representing the SE(3) transformation.
        """

        if quaternion is None:
            quaternion = jnp.array([1.0, 0, 0, 0])

        if translation is None:
            translation = jnp.zeros(3)

        orientation_wxyz = jnp.array(quaternion).astype(float)
        position = jnp.array(translation).astype(float)

        assert position.shape[-1] == 3
        assert orientation_wxyz.shape[-1] == 4

        rotation = jaxlie.SO3(wxyz=orientation_wxyz)

        if normalize_quaternion:
            rotation = rotation.normalize()

        H = jaxlie.SE3.from_rotation_and_translation(
            rotation=rotation, translation=position
        )

        if inverse:
            H = H.inverse()

        return H.as_matrix()

    @staticmethod
    def from_rotation_and_translation(
        rotation: jtp.MatrixLike | None = None,
        translation: jtp.VectorLike | None = None,
        inverse: jtp.BoolLike = False,
    ) -> jtp.Matrix:
        """
        Build a transformation matrix from a rotation matrix and a translation.

        Args:
            rotation:
                A 3x3 rotation matrix representing a SO(3) orientation.
                It defaults to the identity.
            translation:
                A 3D vector representing a translation. It defaults to zero.
            inverse: Whether to compute the inverse transformation.

        Returns:
            The 4x4 transformation matrix representing the SE(3) transformation.
        """

        if rotation is None:
            rotation = jnp.eye(3)

        if translation is None:
            translation = jnp.zeros(3)

        rotation_matrix = jnp.array(rotation).astype(float)
        position = jnp.array(translation).astype(float)

        assert position.size == 3
        assert rotation_matrix.shape == (3, 3)

        H = jaxlie.SE3.from_rotation_and_translation(
            rotation=jaxlie.SO3.from_matrix(rotation_matrix), translation=position
        )

        if inverse:
            H = H.inverse()

        return H.as_matrix()

    @staticmethod
    def inverse(transform: jtp.MatrixLike) -> jtp.Matrix:
        """
        Compute the inverse of a transformation matrix.

        Args:
            transform: A 4x4 transformation matrix.

        Returns:
            The 4x4 inverse transformation matrix.
        """

        H = jaxlie.SE3.from_matrix(matrix=transform)

        return H.inverse().as_matrix()
================================================
FILE: src/jaxsim/math/utils.py
================================================
import jax
import jax.numpy as jnp
import jaxsim.typing as jtp
def _make_safe_norm(axis, keepdims):
@jax.custom_jvp
def _safe_norm(array: jtp.ArrayLike) -> jtp.Array:
"""
Compute an array norm handling NaNs and making sure that
it is safe to get the gradient.
Args:
array: The array for which to compute the norm.
Returns:
The norm of the array with handling for zero arrays to avoid NaNs.
"""
# Compute the norm of the array along the specified axis.
return jnp.linalg.norm(array, axis=axis, keepdims=keepdims)
@_safe_norm.defjvp
def _safe_norm_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
# Check if the entire array is composed of zeros.
is_zero = jnp.allclose(x, 0)
# Replace zeros with an array of ones temporarily to avoid division by zero.
# This ensures the computation of norm does not produce NaNs or Infs.
array = jnp.where(is_zero, jnp.ones_like(x), x)
# Compute the norm of the array along the specified axis.
norm = jnp.linalg.norm(array, axis=axis, keepdims=keepdims)
dot = jnp.sum(array * x_dot, axis=axis, keepdims=keepdims)
tangent = jnp.where(is_zero, 0.0, dot / norm)
return jnp.where(is_zero, 0.0, norm), tangent
return _safe_norm
def safe_norm(array: jtp.ArrayLike, *, axis=None, keepdims: bool = False) -> jtp.Array:
    """
    Compute the norm of an array so that it is safe to get its gradient.

    The norm is computed by the function created by `_make_safe_norm`
    for the given axis and keepdims options.

    Args:
        array: The array for which to compute the norm.
        axis: The axis for which to compute the norm.
        keepdims: Whether to keep the dimensions of the input

    Returns:
        The norm of the array with handling for zero arrays to avoid NaNs.
    """

    norm_function = _make_safe_norm(axis, keepdims)

    return norm_function(array)
================================================
FILE: src/jaxsim/mujoco/__init__.py
================================================
from .loaders import ModelToMjcf, RodModelToMjcf, SdfToMjcf, UrdfToMjcf
from .model import MujocoModelHelper
from .utils import MujocoCamera, mujoco_data_from_jaxsim
from .visualizer import MujocoVideoRecorder, MujocoVisualizer
================================================
FILE: src/jaxsim/mujoco/__main__.py
================================================
import argparse
import pathlib
import sys
import time
import numpy as np
from . import ModelToMjcf, MujocoModelHelper, MujocoVisualizer
if __name__ == "__main__":

    # Command-line entry point: convert a URDF/SDF description to MJCF and,
    # optionally, print it, export it to a file, and open it in the Mujoco viewer.

    parser = argparse.ArgumentParser(
        prog="jaxsim.mujoco",
        description="Process URDF and SDF files for Mujoco usage.",
    )

    parser.add_argument(
        "-d",
        "--description",
        required=True,
        metavar="INPUT_FILE",
        type=pathlib.Path,
        help="Path to the URDF or SDF file.",
    )

    parser.add_argument(
        "-m",
        "--model-name",
        metavar="NAME",
        type=str,
        default=None,
        help="The target model of a SDF description if multiple models exists.",
    )

    parser.add_argument(
        "-e",
        "--export",
        metavar="MJCF_FILE",
        type=pathlib.Path,
        default=None,
        help="Path to the exported MJCF file.",
    )

    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        help="Override the output MJCF file if it already exists (default: %(default)s).",
    )

    parser.add_argument(
        "-p",
        "--print",
        action="store_true",
        default=False,
        help="Print in the stdout the exported MJCF string (default: %(default)s).",
    )

    parser.add_argument(
        "-v",
        "--visualize",
        action="store_true",
        default=False,
        help="Visualize the description in the Mujoco viewer (default: %(default)s).",
    )

    parser.add_argument(
        "-b",
        "--base-position",
        metavar=("x", "y", "z"),
        nargs=3,
        type=float,
        default=None,
        help="Override the base position (supports only floating-base models).",
    )

    parser.add_argument(
        "-q",
        "--base-quaternion",
        metavar=("w", "x", "y", "z"),
        nargs=4,
        type=float,
        default=None,
        help="Override the base quaternion (supports only floating-base models).",
    )

    args = parser.parse_args()

    # ==================
    # Validate arguments
    # ==================

    # Expand the user home first (a leading '~' is only recognized at the beginning
    # of the path), then make the path absolute w.r.t. the current directory.
    args.description = args.description.expanduser().absolute()

    if not args.description.is_file():
        # Note: parser.error() prints the message and terminates the program.
        parser.error(f"The URDF/SDF file '{args.description}' does not exist.")

    # Expand the path of the output MJCF file if not absolute.
    if args.export is not None:

        args.export = args.export.expanduser().absolute()

        if args.export.is_file() and not args.force:
            msg = "The output file '{}' already exists, use '--force' to override."
            parser.error(msg.format(args.export))

    # ================================================
    # Load the URDF/SDF file and produce a MJCF string
    # ================================================

    # Imported here since it is only needed when running the module as a script.
    from .loaders import load_rod_model

    # Load the model explicitly so that the '--model-name' option is honored.
    rod_model = load_rod_model(
        model_description=args.description,
        is_urdf=None,
        model_name=args.model_name,
    )

    mjcf_string, assets = ModelToMjcf.convert(rod_model)

    if args.print:
        print(mjcf_string, flush=True)

    # ========================================
    # Write the MJCF string to the output file
    # ========================================

    if args.export is not None:
        with open(args.export, "w", encoding="utf-8") as file:
            file.write(mjcf_string)

    # =======================================
    # Visualize the MJCF in the Mujoco viewer
    # =======================================

    if args.visualize:

        mj_model_helper = MujocoModelHelper.build_from_xml(
            mjcf_description=mjcf_string, assets=assets
        )

        viz = MujocoVisualizer(model=mj_model_helper.model, data=mj_model_helper.data)

        with viz.open() as viewer:

            # Apply the optional base pose overrides while holding the viewer lock.
            with viewer.lock():

                if args.base_position is not None:
                    mj_model_helper.set_base_position(
                        position=np.array(args.base_position)
                    )

                if args.base_quaternion is not None:
                    mj_model_helper.set_base_orientation(
                        orientation=np.array(args.base_quaternion)
                    )

            viz.sync(viewer=viewer)

            # Keep the process alive until the viewer window is closed.
            while viewer.is_running():
                time.sleep(0.500)

    # =============================
    # Exit the program with success
    # =============================

    sys.exit(0)
================================================
FILE: src/jaxsim/mujoco/loaders.py
================================================
import pathlib
import tempfile
import warnings
from collections.abc import Sequence
from typing import Any
import jaxlie
import mujoco as mj
import numpy as np
import rod.urdf.exporter
from lxml import etree as ET
from jaxsim import logging
from .utils import MujocoCamera
MujocoCameraType = (
MujocoCamera | Sequence[MujocoCamera] | dict[str, str] | Sequence[dict[str, str]]
)
def load_rod_model(
    model_description: str | pathlib.Path | rod.Model,
    is_urdf: bool | None = None,
    model_name: str | None = None,
) -> rod.Model:
    """
    Load a ROD model from a URDF/SDF file or a ROD model.

    Args:
        model_description: The URDF/SDF file or ROD model to load.
        is_urdf: Whether to force parsing the model description as a URDF file.
        model_name: The name of the model to load from the resource.

    Returns:
        rod.Model: The loaded ROD model.

    Raises:
        RuntimeError: If the description contains no model.
        ValueError:
            If the requested model is not found, or if the description contains
            multiple models and no model name is specified.
    """

    # Parse the SDF resource and collect the models it contains.
    sdf_element = rod.Sdf.load(sdf=model_description, is_urdf=is_urdf)
    available_models = sdf_element.models()
    number_of_models = len(available_models)

    # A resource without models cannot be processed.
    if number_of_models == 0:
        raise RuntimeError("Failed to find any model in the model description")

    # With a single model, the name (if given) is only used as a consistency check.
    if number_of_models == 1:

        only_model = available_models[0]

        if model_name is not None and only_model.name != model_name:
            raise ValueError(f"Model '{model_name}' not found in the description")

        return only_model

    # With multiple models, the name is mandatory to select the right one.
    if model_name is None:
        msg = "The resource has multiple models. Please specify the model name."
        raise ValueError(msg)

    # Index the models by name and select the requested one.
    models_by_name = {candidate.name: candidate for candidate in available_models}

    if model_name not in models_by_name:
        raise ValueError(f"Model '{model_name}' not found in the resource")

    return models_by_name[model_name]
class ModelToMjcf:
    """
    Converter from a URDF/SDF file or a ROD model to a Mujoco MJCF string.
    """

    @staticmethod
    def convert(
        model: str | pathlib.Path | rod.Model,
        considered_joints: list[str] | None = None,
        plane_normal: tuple[float, float, float] = (0, 0, 1),
        heightmap: bool | None = None,
        heightmap_samples_xy: tuple[int, int] = (101, 101),
        cameras: MujocoCameraType = (),
    ) -> tuple[str, dict[str, Any]]:
        """
        Convert a model to a Mujoco MJCF string.

        Args:
            model: The URDF/SDF file or ROD model to convert.
            considered_joints: The list of joint names to consider in the conversion.
            plane_normal: The normal vector of the plane.
            heightmap: Whether to generate a heightmap.
            heightmap_samples_xy: The number of points in the heightmap grid.
            cameras: The custom cameras to add to the scene.

        Returns:
            A tuple containing the MJCF string and the dictionary of assets.

        Raises:
            TypeError: If the type of `model` is not supported.
        """

        if isinstance(model, rod.Model):
            # The model is already a ROD model, nothing to load.
            rod_model = model

        elif isinstance(model, (str, pathlib.Path)):
            # Load the ROD model from the model description.
            rod_model = load_rod_model(
                model_description=model,
                is_urdf=None,
                model_name=None,
            )

        else:
            raise TypeError(f"Unsupported type for 'model': {type(model)}")

        # Delegate the actual conversion to the ROD-based converter.
        return RodModelToMjcf.convert(
            rod_model=rod_model,
            considered_joints=considered_joints,
            plane_normal=plane_normal,
            heightmap=heightmap,
            heightmap_samples_xy=heightmap_samples_xy,
            cameras=cameras,
        )
class RodModelToMjcf:
"""
Class to convert a ROD model to a Mujoco MJCF string.
"""
@staticmethod
def assets_from_rod_model(
    rod_model: rod.Model,
) -> dict[str, bytes]:
    """
    Generate a dictionary of assets from a ROD model.

    The mesh URIs of all the visual and collision elements of the model are
    resolved to files, whose content is read as bytes.

    Args:
        rod_model: The ROD model to extract the assets from.

    Returns:
        dict: A dictionary mapping each mesh URI to the bytes of the resolved file.
    """

    import resolve_robotics_uri_py

    # Map each mesh URI to the path of the resolved file.
    resolved_paths = {}

    for link in rod_model.links():

        # Process first the visuals and then the collisions of each link.
        for get_elements in (link.visuals, link.collisions):

            for element in get_elements():

                mesh = element.geometry.mesh

                # Skip the geometries that are not meshes with a valid URI.
                if not (mesh and mesh.uri):
                    continue

                resolved_paths[mesh.uri] = (
                    resolve_robotics_uri_py.resolve_robotics_uri(mesh.uri)
                )

    # Read the content of all the resolved files.
    return {uri: path.read_bytes() for uri, path in resolved_paths.items()}
@staticmethod
def add_floating_joint(
    urdf_string: str,
    base_link_name: str,
    floating_joint_name: str = "world_to_base",
) -> str:
    """
    Add a floating joint to a URDF string.

    The floating joint connects a link named "world" (created if missing)
    to the base link.

    Args:
        urdf_string: The URDF string to modify.
        base_link_name: The name of the base link to attach the floating joint.
        floating_joint_name: The name of the floating joint to add.

    Returns:
        str: The modified URDF string. If the URDF already has a joint named
            `floating_joint_name`, a warning is emitted and the URDF is returned
            without additions.

    Raises:
        ValueError: If the base link is not found in the URDF.
    """

    # Parse the URDF string as XML (etree) directly from memory. This avoids
    # a round trip through a temporary file written with the locale encoding.
    parser = ET.XMLParser(remove_blank_text=True)
    root: ET._Element = ET.fromstring(text=urdf_string.encode(), parser=parser)

    if root.find(f".//joint[@name='{floating_joint_name}']") is not None:
        msg = f"The URDF already has a floating joint '{floating_joint_name}'"
        warnings.warn(msg, stacklevel=2)
        return ET.tostring(root, pretty_print=True).decode()

    # Check that the base link exists before modifying the tree.
    if root.find(f".//link[@name='{base_link_name}']") is None:
        raise ValueError(f"Link '{base_link_name}' not found in the URDF")

    # Create the "world" link if it doesn't exist.
    if root.find(".//link[@name='world']") is None:
        ET.SubElement(root, "link", name="world")

    # Create the floating joint.
    world_to_base = ET.SubElement(
        root, "joint", name=floating_joint_name, type="floating"
    )

    # Attach the floating joint to the base link.
    ET.SubElement(world_to_base, "parent", link="world")
    ET.SubElement(world_to_base, "child", link=base_link_name)

    return ET.tostring(root, pretty_print=True).decode()
@staticmethod
def convert(
    rod_model: rod.Model,
    considered_joints: list[str] | None = None,
    plane_normal: tuple[float, float, float] = (0, 0, 1),
    heightmap: bool | None = None,
    heightmap_samples_xy: tuple[int, int] = (101, 101),
    cameras: MujocoCameraType = (),
) -> tuple[str, dict[str, Any]]:
    """
    Convert a ROD model to a Mujoco MJCF string.

    Args:
        rod_model: The ROD model to convert.
        considered_joints: The list of joint names to consider in the conversion.
        plane_normal: The normal vector of the plane.
        heightmap: Whether to generate a heightmap.
        heightmap_samples_xy: The number of points in the heightmap grid.
        cameras: The custom cameras to add to the scene.

    Returns:
        A tuple containing the MJCF string and the dictionary of assets.

    Raises:
        ValueError:
            If a considered joint is not part of the model, or if the joints of
            the model loaded by Mujoco do not match the considered joints.
        RuntimeError:
            If the MJCF already has an <actuator> element, or if the <worldbody>
            element of the model cannot be found.

    Note:
        The joints of `rod_model` that are not considered are converted
        in-place to fixed joints.
    """

    # -------------------------------------
    # Convert the model description to URDF
    # -------------------------------------

    # Create a dictionary of joints for quick access.
    joints_dict = {j.name: j for j in rod_model.joints()}
    model_joint_names = set(joints_dict)

    # Consider all joints if not specified otherwise.
    considered_joints = set(
        considered_joints
        if considered_joints is not None
        else [j.name for j in rod_model.joints() if j.type != "fixed"]
    )

    # If considered joints are passed, make sure that they are all part of the model.
    extra_joints = considered_joints - model_joint_names

    if extra_joints:
        msg = f"Couldn't find the following joints in the model: '{extra_joints}'"
        raise ValueError(msg)

    # Convert all the joints not considered to fixed joints.
    # Note that this modifies the joints of the input ROD model.
    for joint_name in model_joint_names - considered_joints:
        joints_dict[joint_name].type = "fixed"

    # Convert the ROD model to URDF.
    urdf_string = rod.urdf.exporter.UrdfExporter(
        gazebo_preserve_fixed_joints=False, pretty=True
    ).to_urdf_string(
        sdf=rod.Sdf(model=rod_model, version="1.7"),
    )

    # -------------------------------------
    # Add a floating joint if floating-base
    # -------------------------------------

    base_link_name = rod_model.get_canonical_link()

    if not rod_model.is_fixed_base():
        considered_joints |= {"world_to_base"}
        urdf_string = RodModelToMjcf.add_floating_joint(
            urdf_string=urdf_string,
            base_link_name=base_link_name,
            floating_joint_name="world_to_base",
        )

    # ---------------------------------------
    # Inject the <mujoco> element in the URDF
    # ---------------------------------------

    # The URDF is parsed only once: both the injection of the <mujoco> element
    # and the post-processing of the dummy visuals operate on the same tree.
    parser = ET.XMLParser(remove_blank_text=True)
    urdf_root: ET._Element = ET.fromstring(text=urdf_string.encode(), parser=parser)

    urdf_mujoco_element = urdf_root.find("./mujoco")

    if urdf_mujoco_element is None:
        urdf_mujoco_element = ET.SubElement(urdf_root, "mujoco")

    ET.SubElement(
        urdf_mujoco_element,
        "compiler",
        balanceinertia="true",
        discardvisual="false",
    )

    # ------------------------------
    # Post-process all dummy visuals
    # ------------------------------

    # Give a tiny radius to all dummy spheres
    for geometry in urdf_root.findall(".//visual/geometry[sphere]"):
        sphere = geometry.find("./sphere")
        radius = np.fromstring(sphere.attrib["radius"], sep=" ", dtype=float)
        if np.allclose(radius, np.zeros(1)):
            sphere.set("radius", "0.001")

    # Give a tiny volume to all dummy boxes
    for geometry in urdf_root.findall(".//visual/geometry[box]"):
        box = geometry.find("./box")
        size = np.fromstring(box.attrib["size"], sep=" ", dtype=float)
        if np.allclose(size, np.zeros(3)):
            box.set("size", "0.001 0.001 0.001")

    urdf_string = ET.tostring(urdf_root, pretty_print=True).decode()

    # ------------------------
    # Convert the URDF to MJCF
    # ------------------------

    # Load the URDF model into Mujoco.
    assets = RodModelToMjcf.assets_from_rod_model(rod_model=rod_model)
    mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets)

    # Get the joint names.
    mj_joint_names = {
        mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx)
        for idx in range(mj_model.njnt)
    }

    # Check that the Mujoco model only has the considered joints.
    if mj_joint_names != considered_joints:
        # The symmetric difference collects both the extra and the missing joints.
        extra_joints = mj_joint_names ^ considered_joints
        msg = "The Mujoco model has the following extra/missing joints: '{}'"
        raise ValueError(msg.format(extra_joints))

    # Windows locks open files, so the temporary file is created with
    # delete=False and closed right away. Mujoco then writes to it by name,
    # and the file is removed manually afterwards.
    with tempfile.NamedTemporaryFile(
        suffix=".xml", prefix=f"{rod_model.name}_", delete=False
    ) as tmp:
        temp_filename = tmp.name

    try:
        # Convert the in-memory Mujoco model to MJCF.
        mj.mj_saveLastXML(temp_filename, mj_model)

        # Parse the MJCF file as XML.
        parser = ET.XMLParser(remove_blank_text=True)
        tree = ET.parse(source=temp_filename, parser=parser)

    finally:
        pathlib.Path(temp_filename).unlink()

    # Get the root element.
    root: ET._Element = tree.getroot()

    # Find the <mujoco> element (might be the root itself).
    mujoco_element: ET._Element = next(iter(root.iter("mujoco")))

    # --------------
    # Add the frames
    # --------------

    for frame in rod_model.frames():

        parent_name = frame.attached_to
        parent_element = mujoco_element.find(f".//body[@name='{parent_name}']")

        # Fall back to <worldbody> when the frame is attached to the base link
        # and no body with its name exists in the MJCF.
        if parent_element is None and parent_name == base_link_name:
            parent_element = mujoco_element.find(".//worldbody")

        if parent_element is not None:
            quat = jaxlie.SO3.from_rpy_radians(*frame.pose.rpy).wxyz
            ET.SubElement(
                parent_element,
                "site",
                name=frame.name,
                pos=" ".join(map(str, frame.pose.xyz)),
                quat=" ".join(map(str, quat)),
            )
        else:
            logging.debug(f"Parent link '{parent_name}' not found")

    # --------------
    # Add the motors
    # --------------

    if len(mujoco_element.findall(".//actuator")) > 0:
        raise RuntimeError("The model already has <actuator> elements.")

    # Add the actuator element.
    actuator_element = ET.SubElement(mujoco_element, "actuator")

    # Add a motor for each joint.
    for joint_element in mujoco_element.findall(".//joint"):

        joint_name = joint_element.attrib["name"]
        assert joint_name in considered_joints, joint_name

        # Joints of type 'free' and 'ball' do not get a motor.
        if joint_element.attrib.get("type", "hinge") in {"free", "ball"}:
            continue

        ET.SubElement(
            actuator_element,
            "motor",
            name=f"{joint_name}_motor",
            joint=joint_name,
            gear="1",
        )

    # ---------------------------------------------
    # Set full transparency of collision geometries
    # ---------------------------------------------

    # Get all the (optional) names of the URDF collision elements
    collision_names = {
        c.attrib["name"]
        for c in urdf_root.findall(".//collision[geometry]")
        if "name" in c.attrib
    }

    # Set alpha=0 to the color of all collision elements
    for geometry_element in mujoco_element.findall(".//geom[@rgba]"):
        if geometry_element.attrib.get("name") in collision_names:
            r, g, b, _ = geometry_element.attrib["rgba"].split(" ")
            geometry_element.set("rgba", f"{r} {g} {b} 0")

    # -----------------------
    # Create the scene assets
    # -----------------------

    asset_element = mujoco_element.find(".//asset")

    if asset_element is None:
        asset_element = ET.SubElement(mujoco_element, "asset")

    ET.SubElement(
        asset_element,
        "texture",
        type="skybox",
        builtin="gradient",
        rgb1="0.3 0.5 0.7",
        rgb2="0 0 0",
        width="512",
        height="512",
    )

    ET.SubElement(
        asset_element,
        "texture",
        name="plane_texture",
        type="2d",
        builtin="checker",
        rgb1="0.1 0.2 0.3",
        rgb2="0.2 0.3 0.4",
        width="512",
        height="512",
        mark="cross",
        markrgb=".8 .8 .8",
    )

    ET.SubElement(
        asset_element,
        "material",
        name="plane_material",
        texture="plane_texture",
        reflectance="0.2",
        texrepeat="5 5",
        texuniform="true",
    )

    if heightmap:
        ET.SubElement(
            asset_element,
            "hfield",
            name="terrain",
            nrow=str(int(heightmap_samples_xy[0])),
            ncol=str(int(heightmap_samples_xy[1])),
            # The following 'size' is a placeholder, it is updated dynamically
            # when a hfield/heightmap is stored into MjData.
            size="1 1 1 1",
        )

    # ----------------------------------
    # Populate the scene with the assets
    # ----------------------------------

    worldbody_scene_element = ET.SubElement(mujoco_element, "worldbody")

    ET.SubElement(
        worldbody_scene_element,
        "geom",
        name="floor",
        type="plane" if not heightmap else "hfield",
        size="0 0 0.05",
        material="plane_material",
        condim="3",
        contype="1",
        conaffinity="1",
        zaxis=" ".join(map(str, plane_normal)),
        **({"hfield": "terrain"} if heightmap else {}),
    )

    ET.SubElement(
        worldbody_scene_element,
        "light",
        name="sun",
        mode="fixed",
        directional="true",
        castshadow="true",
        pos="0 0 10",
        dir="0 0 -1",
    )

    # -------------------------------------------------------
    # Add a camera following the CoM of the worldbody element
    # -------------------------------------------------------

    worldbody_element = None

    # Find the <worldbody> element of our model by searching the one that contains
    # all the considered joints. This is needed because there might be multiple
    # <worldbody> elements inside <mujoco>.
    for wb in mujoco_element.findall(".//worldbody"):
        if all(
            wb.find(f".//joint[@name='{j}']") is not None for j in considered_joints
        ):
            worldbody_element = wb
            break

    if worldbody_element is None:
        raise RuntimeError("Failed to find the <worldbody> element of the model")

    # Camera attached to the model
    # It can be manually copied from `python -m mujoco.viewer --mjcf=<file>`
    ET.SubElement(
        worldbody_element,
        "camera",
        name="track",
        mode="trackcom",
        pos="1.930 -2.279 0.556",
        xyaxes="0.771 0.637 0.000 -0.116 0.140 0.983",
        fovy="60",
    )

    # Add user-defined camera.
    for camera in cameras if isinstance(cameras, Sequence) else [cameras]:

        mj_camera = (
            camera
            if isinstance(camera, MujocoCamera)
            else MujocoCamera.build(**camera)
        )

        ET.SubElement(worldbody_element, "camera", mj_camera.asdict())

    # ------------------------------------------------
    # Add a light following the CoM of the first link
    # ------------------------------------------------

    if not rod_model.is_fixed_base():

        # Light attached to the model
        ET.SubElement(
            worldbody_element,
            "light",
            name="light_model",
            mode="targetbodycom",
            target=worldbody_element.find(".//body").attrib["name"],
            directional="false",
            castshadow="true",
            pos="1 0 5",
        )

    # --------------------------------
    # Return the resulting MJCF string
    # --------------------------------

    mjcf_string = ET.tostring(root, pretty_print=True).decode()

    return mjcf_string, assets
class UrdfToMjcf:
    """
    Deprecated converter from a URDF file to a Mujoco MJCF string.
    """

    @staticmethod
    def convert(
        urdf: str | pathlib.Path,
        considered_joints: list[str] | None = None,
        model_name: str | None = None,
        plane_normal: tuple[float, float, float] = (0, 0, 1),
        heightmap: bool | None = None,
        cameras: MujocoCameraType = (),
    ) -> tuple[str, dict[str, Any]]:
        """
        Convert a URDF file to a Mujoco MJCF string.

        Args:
            urdf: The URDF file to convert.
            considered_joints: The list of joint names to consider in the conversion.
            model_name: The name of the model to convert.
            plane_normal: The normal vector of the plane.
            heightmap: Whether to generate a heightmap.
            cameras: The list of cameras to add to the scene.

        Returns:
            tuple: A tuple containing the MJCF string and the assets dictionary.
        """

        logging.warning("This method is deprecated. Use 'ModelToMjcf.convert' instead.")

        # Load the ROD model, forcing the description to be parsed as URDF,
        # and convert it to MJCF.
        return RodModelToMjcf.convert(
            rod_model=load_rod_model(
                model_description=urdf,
                is_urdf=True,
                model_name=model_name,
            ),
            considered_joints=considered_joints,
            plane_normal=plane_normal,
            heightmap=heightmap,
            cameras=cameras,
        )
class SdfToMjcf:
    """
    Deprecated converter from a SDF file to a Mujoco MJCF string.
    """

    @staticmethod
    def convert(
        sdf: str | pathlib.Path,
        considered_joints: list[str] | None = None,
        model_name: str | None = None,
        plane_normal: tuple[float, float, float] = (0, 0, 1),
        heightmap: bool | None = None,
        cameras: MujocoCameraType = (),
    ) -> tuple[str, dict[str, Any]]:
        """
        Convert a SDF file to a Mujoco MJCF string.

        Args:
            sdf: The SDF file to convert.
            considered_joints: The list of joint names to consider in the conversion.
            model_name: The name of the model to convert.
            plane_normal: The normal vector of the plane.
            heightmap: Whether to generate a heightmap.
            cameras: The list of cameras to add to the scene.

        Returns:
            tuple: A tuple containing the MJCF string and the assets dictionary.
        """

        logging.warning("This method is deprecated. Use 'ModelToMjcf.convert' instead.")

        # Load the ROD model, forcing the description to be parsed as SDF,
        # and convert it to MJCF.
        return RodModelToMjcf.convert(
            rod_model=load_rod_model(
                model_description=sdf,
                is_urdf=False,
                model_name=model_name,
            ),
            considered_joints=considered_joints,
            plane_normal=plane_normal,
            heightmap=heightmap,
            cameras=cameras,
        )
================================================
FILE: src/jaxsim/mujoco/model.py
================================================
from __future__ import annotations
import functools
import pathlib
from collections.abc import Callable, Sequence
from typing import Any
import mujoco as mj
import numpy as np
import numpy.typing as npt
import xmltodict
from scipy.spatial.transform import Rotation
import jaxsim.typing as jtp
HeightmapCallable = Callable[[jtp.FloatLike, jtp.FloatLike], jtp.FloatLike]
class MujocoModelHelper:
"""
Helper class to create and interact with Mujoco models and data objects.
"""
def __init__(self, model: mj.MjModel, data: mj.MjData | None = None) -> None:
    """
    Initialize the MujocoModelHelper object.

    Args:
        model: A Mujoco model object.
        data: A Mujoco data object. If None, a new one will be created.
    """

    # Create a new data object from the model if none is provided.
    if data is None:
        data = mj.MjData(model)

    self.model = model
    self.data = data

    # Populate the data with kinematics.
    mj.mj_forward(self.model, self.data)

    # Keep the cache of this method local to improve GC.
    self.mask_qpos = functools.cache(self._mask_qpos)
@staticmethod
def build_from_xml(
    mjcf_description: str | pathlib.Path,
    assets: dict[str, Any] | None = None,
    heightmap: HeightmapCallable | None = None,
    heightmap_name: str = "terrain",
    heightmap_radius_xy: tuple[float, float] = (1.0, 1.0),
) -> MujocoModelHelper:
    """
    Build a Mujoco model from an MJCF description.

    Args:
        mjcf_description:
            A string containing the XML description of the Mujoco model
            or a path to a file containing the XML description.
        assets: An optional dictionary containing the assets of the model.
        heightmap:
            A function in two variables that returns the height of a terrain
            in the specified coordinate point.
        heightmap_name:
            The default name of the heightmap in the MJCF description
            to load the corresponding configuration.
        heightmap_radius_xy:
            The extension of the heightmap in the x-y surface corresponding to the
            plane over which the grid of the sampled heightmap is generated.

    Returns:
        A MujocoModelHelper object.

    Raises:
        ValueError: If a heightmap is passed and `heightmap_name` is not in the MJCF.
    """

    # Read the XML description if it is a path to file.
    if isinstance(mjcf_description, pathlib.Path):
        mjcf_description = mjcf_description.read_text()

    hfield = None

    if heightmap is not None:

        mjcf_dict = xmltodict.parse(xml_input=mjcf_description)
        asset_dict = mjcf_dict["mujoco"]["asset"]

        # Create a dictionary of all hfield configurations from the MJCF.
        # A single hfield is not parsed as a list, hence it is wrapped into one.
        hfields = asset_dict.get("hfield", [])

        if not isinstance(hfields, list):
            hfields = [hfields]

        hfields_dict = {entry["@name"]: entry for entry in hfields}

        if heightmap_name not in hfields_dict:
            raise ValueError(f"Heightmap '{heightmap_name}' not found in MJCF")

        hfield_element = hfields_dict[heightmap_name]

        # Generate the hfield by sampling the heightmap function.
        nrow = int(hfield_element["@nrow"])
        ncol = int(hfield_element["@ncol"])

        hfield = generate_hfield(
            heightmap=heightmap,
            samples_xy=(nrow, ncol),
            radius_xy=heightmap_radius_xy,
        )

        # Update dynamically the '/asset/hfield[@name=heightmap_name]@size' attribute
        # with the information of the sampled points.
        # This is necessary for correctly rendering the heightmap over the
        # specified xy area with the correct z elevation.
        size = [float(el) for el in hfield_element["@size"].split(" ")]
        size[0], size[1] = heightmap_radius_xy
        size[2] = 1.0

        # The following could be zero but Mujoco complains if it's exactly zero.
        size[3] = max(0.000_001, -min(hfield))

        # Replace the 'size' attribute.
        hfield_element["@size"] = " ".join(str(el) for el in size)

        # Update the hfield elements of the original MJCF.
        # Only the hfield corresponding to 'heightmap_name' was actually edited.
        asset_dict["hfield"] = list(hfields_dict.values())

        # Serialize the updated MJCF to XML.
        mjcf_description = xmltodict.unparse(input_dict=mjcf_dict, pretty=True)

    # Create the Mujoco model from the XML and, optionally, the dictionary of assets.
    model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets)
    data = mj.MjData(model)

    # Store the sampled heightmap into the Mujoco model.
    if heightmap is not None:
        assert hfield is not None
        model.hfield_data = hfield

    return MujocoModelHelper(model=model, data=data)
def time(self) -> float:
    """Return the simulation time stored in the Mujoco data."""

    simulation_time = self.data.time
    return simulation_time
def timestep(self) -> float:
    """Return the simulation timestep stored in the options of the Mujoco model."""

    options = self.model.opt
    return options.timestep
def gravity(self) -> npt.NDArray:
    """
    Return the 3D gravity vector.

    Returns:
        A copy of the gravity vector stored in the options of the Mujoco model.
    """

    # The gravity is part of the model options and it is already a 3D vector.
    return np.array(self.model.opt.gravity)
# =========================
# Methods for the base link
# =========================
def is_floating_base(self) -> bool:
    """Return true if the model is floating-base."""

    # A body with no joints is considered a fixed-base model.
    # In fact, in mujoco, a floating-base model has a 6 DoFs first joint.
    has_joints = self.number_of_joints() != 0

    # Otherwise, the model is floating-base if its first joint is a free joint.
    return has_joints and self.model.jnt_type[0] == mj.mjtJoint.mjJNT_FREE
def is_fixed_base(self) -> bool:
    """Return true if the model is fixed-base, i.e. it is not floating-base."""

    floating = self.is_floating_base()
    return not floating
def base_link(self) -> str:
    """Return the name of the base link."""

    # The base link is the body with index 0 for fixed-base models,
    # and the body with index 1 for floating-base models.
    base_body_idx = 0 if self.is_fixed_base() else 1

    return mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, base_body_idx)
def base_position(self) -> npt.NDArray:
    """Return the 3D position of the base link."""

    # For floating-base models, the position is read from the first three
    # elements of the generalized position.
    if self.is_floating_base():
        return self.data.qpos[:3]

    # For fixed-base models, the position of the base body is returned.
    return self.body_position(body_name=self.base_link())
def base_orientation(self, dcm: bool = False) -> npt.NDArray:
    """
    Return the orientation of the base link.

    Args:
        dcm: Whether to return the orientation as a rotation matrix.

    Returns:
        The orientation of the base link read from the Mujoco data. For
        floating-base models it is either the quaternion or, if `dcm` is true,
        the 3x3 rotation matrix of the base body. For fixed-base models the
        result of `body_orientation` is returned.
    """

    base_name = self.base_link()

    if not self.is_floating_base():
        return self.body_orientation(body_name=base_name, dcm=dcm)

    # Read the orientation of the base body. Index 0 of the data arrays is not
    # the base of a floating-base model, whose base link is the body with index 1.
    base_body = self.data.body(base_name)

    if dcm:
        return np.reshape(base_body.xmat, (3, 3))

    return base_body.xquat
def set_base_position(self, position: npt.NDArray) -> None:
    """
    Set the 3D position of the base link.

    Args:
        position: The 3D position of the base link.

    Raises:
        ValueError: If the model is fixed-base or the position has a wrong size.
    """

    if self.is_fixed_base():
        raise ValueError("The position of a fixed-base model cannot be set.")

    # Remove singleton dimensions and make sure to operate on a 1D array.
    xyz = np.atleast_1d(np.array(position).squeeze())

    if xyz.size != 3:
        raise ValueError(f"Wrong position size ({xyz.size})")

    # The base position is stored in the first three elements of qpos.
    self.data.qpos[:3] = xyz
def set_base_orientation(self, orientation: npt.NDArray, dcm: bool = False) -> None:
"""Set the 3D position of the base link."""
if self.is_fixed_base():
raise ValueError("The orientation of a fixed-base model cannot be set.")
orientation = (
np.atleast_2d(np.array(orientation).squeeze())
if dcm
else np.atleast_1d(np.array(orientation).squeeze())
)
if orientation.shape != ((4,) if not dcm else (3, 3)):
raise ValueError(f"Wrong orientation shape {orientation.shape}")
def is_quaternion(Q):
return np.allclose(np.linalg.norm(Q), 1.0)
def is_dcm(R):
return np.allclose(np.linalg.det(R), 1.0) and np.allclose(
R.T @ R, np.eye(3)
)
if not (is_quaternion(orientation) if not dcm else is_dcm(orientation)):
raise ValueError("The orientation is not a valid element of SO(3)")
W_Q_B = (
Rotation.from_matrix(orientation).as_quat(
canonical=True, scalar_first=False
)
if dcm
else orientation
)
self.data.qpos[3:7] = W_Q_B
# ==================
# Methods for joints
# ==================
def number_of_joints(self) -> int:
    """Return the number of joints in the model (the `njnt` field of the model)."""

    joints_count = self.model.njnt
    return joints_count
def number_of_dofs(self) -> int:
    """Return the number of DoFs in the model (the `nq` field of the model)."""

    dofs_count = self.model.nq
    return dofs_count
def joint_names(self) -> list[str]:
    """
    Return the names of the joints in the model.

    The first joint of floating-base models is not included in the list.
    """

    # Skip the first joint of floating-base models.
    first_joint_idx = 0 if self.is_fixed_base() else 1

    names = []

    for joint_idx in range(first_joint_idx, self.number_of_joints()):
        names.append(mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, joint_idx))

    return names
def joint_dofs(self, joint_name: str) -> int:
    """
    Return the number of DoFs of a joint.

    Args:
        joint_name: The name of the joint.

    Returns:
        The size of the position of the joint.

    Raises:
        ValueError: If the joint is not found.
    """

    known_joints = self.joint_names()

    if joint_name not in known_joints:
        raise ValueError(f"Joint '{joint_name}' not found")

    joint_data = self.data.joint(joint_name)

    return joint_data.qpos.size
def joint_position(self, joint_name: str) -> npt.NDArray:
    """
    Return the position of a joint.

    Args:
        joint_name: The name of the joint.

    Raises:
        ValueError: If the joint is not found.
    """

    known_joints = self.joint_names()

    if joint_name not in known_joints:
        raise ValueError(f"Joint '{joint_name}' not found")

    joint_data = self.data.joint(joint_name)

    return joint_data.qpos
def joint_positions(self, joint_names: list[str] | None = None) -> npt.NDArray:
    """
    Return the positions of the joints.

    Args:
        joint_names: The names of the joints. If None, all the joints are used.

    Returns:
        The horizontally stacked positions of the joints.
    """

    # Consider all the joints if not specified otherwise.
    if joint_names is None:
        joint_names = self.joint_names()

    positions = []

    for name in joint_names:
        positions.append(self.joint_position(name))

    return np.hstack(positions)
def set_joint_position(
    self, joint_name: str, position: npt.NDArray | float
) -> None:
    """
    Set the position of a joint.

    Args:
        joint_name: The name of the joint.
        position: The position of the joint.

    Raises:
        ValueError: If the size of the position does not match the joint DoFs.
    """

    # Remove singleton dimensions and make sure to operate on a 1D array.
    position = np.atleast_1d(np.array(position).squeeze())

    dofs = self.joint_dofs(joint_name=joint_name)

    if position.size != dofs:
        raise ValueError(
            f"Wrong position size ({position.size}) of "
            f"{dofs}-DoFs joint '{joint_name}'."
        )

    # Get the address of the first element of the joint position in qpos.
    joint_idx = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name)
    start = self.model.jnt_qposadr[joint_idx]

    # Write the elements of qpos belonging to the joint.
    self.data.qpos[start : start + dofs] = position
def set_joint_positions(
    self, joint_names: Sequence[str], positions: npt.NDArray | list[npt.NDArray]
) -> None:
    """
    Set the positions of multiple joints.

    Args:
        joint_names: The names of the joints to set.
        positions: The positions written at the qpos entries of the joints.
    """

    # The names are converted to a tuple, since the mask is computed
    # by a cached method.
    names = tuple(joint_names)
    qpos_mask = self.mask_qpos(joint_names=names)

    self.data.qpos[qpos_mask] = positions
# ==================
# Methods for bodies
# ==================
def number_of_bodies(self) -> int:
    """Return the number of bodies in the model (the `nbody` field of the model)."""

    bodies_count = self.model.nbody
    return bodies_count
def body_names(self) -> list[str]:
    """Return the names of the bodies in the model, ordered by body index."""

    names = []

    for body_idx in range(self.number_of_bodies()):
        names.append(mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, body_idx))

    return names
def body_position(self, body_name: str) -> npt.NDArray:
    """
    Return the position of a body.

    Args:
        body_name: The name of the body.

    Raises:
        ValueError: If the body is not found.
    """

    known_bodies = self.body_names()

    if body_name not in known_bodies:
        raise ValueError(f"Body '{body_name}' not found")

    body_data = self.data.body(body_name)

    return body_data.xpos
def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray:
    """
    Return the orientation of a body.

    Args:
        body_name: The name of the body.
        dcm: Whether to return the `xmat` field instead of the `xquat` field.

    Returns:
        The `xmat` field of the body data if `dcm` is true, `xquat` otherwise.

    Raises:
        ValueError: If the body is not found.
    """

    known_bodies = self.body_names()

    if body_name not in known_bodies:
        raise ValueError(f"Body '{body_name}' not found")

    body_data = self.data.body(body_name)

    if dcm:
        return body_data.xmat

    return body_data.xquat
# ======================
# Methods for geometries
# ======================
def number_of_geometries(self) -> int:
    """Return how many geometries the model contains (the `ngeom` model field)."""

    geometry_count = self.model.ngeom
    return geometry_count
def geometry_names(self) -> list[str]:
    """Return the name of every geometry of the model, ordered by geometry index."""

    names = []

    for geom_idx in range(self.number_of_geometries()):
        names.append(mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_GEOM, geom_idx))

    return names
def geometry_position(self, geometry_name: str) -> npt.NDArray:
    """
    Return the position of a geometry, read from the `xpos` field of its data.

    Raises:
        ValueError: If the geometry is not part of the model.
    """

    if geometry_name in self.geometry_names():
        return self.data.geom(geometry_name).xpos

    raise ValueError(f"Geometry '{geometry_name}' not found")
def geometry_orientation(
self, geometry_name: str, dcm: bool = False
) -> npt.NDArray:
"""Return the orientation of a geometry."""
if geometry_name not in self.geometry_names():
raise ValueError(f"Geometry '{geometry_name}' not found")
R = np.reshape(self.data.geom(geometry_name).xmat, newshape=(3, 3))
if dcm:
return R
q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True, scalar_first=False)
return q_xyzw
# ===============
# Private methods
# ===============
def _mask_qpos(self, joint_names: tuple[str, ...]) -> npt.NDArray:
    """
    Create a mask to access the DoFs of the desired `joint_names` in the `qpos` array.

    Args:
        joint_names: A tuple containing the names of the joints.

    Returns:
        A 1D array containing the indices of the `qpos` array to access the DoFs of
        the desired `joint_names`. The array is empty if no joint is given.

    Note:
        This method takes a tuple of strings so that the argument is hashable,
        which is required to cache the output mask for each combination of
        joint names.
    """

    # `np.hstack` cannot stack an empty sequence, therefore the case of no
    # joints is handled explicitly by returning an empty integer mask.
    if len(joint_names) == 0:
        return np.array([], dtype=int)

    masks = []

    for joint_name in joint_names:

        # Get the index of the joint in the Mujoco model.
        idx = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name)

        # The DoFs of a joint are stored contiguously in `qpos`, starting from
        # the address of the joint.
        first_qpos_idx = self.model.jnt_qposadr[idx]
        masks.append(first_qpos_idx + np.arange(self.joint_dofs(joint_name=joint_name)))

    # Flatten the per-joint indices into the final 1D mask.
    return np.hstack(masks)
def generate_hfield(
    heightmap: HeightmapCallable,
    samples_xy: tuple[int, int] = (11, 11),
    radius_xy: tuple[float, float] = (1.0, 1.0),
) -> npt.NDArray:
    """
    Sample a heightmap function over a regular grid and return the elevations.

    The grid spans `[-radius_xy[0], radius_xy[0]]` along x and
    `[-radius_xy[1], radius_xy[1]]` along y. The samples are stored row by row:
    each row corresponds to a fixed y coordinate, and x changes along the row.

    Args:
        heightmap:
            A function that takes two arguments (x, y) and returns the height
            at that point.
        samples_xy: The number of samples of the grid along x and y.
        radius_xy: The half-extension of the sampled area along x and y.

    Returns:
        A flat array of the sampled terrain heightmap.
    """

    # Coordinates of the grid nodes along the two axes.
    x_coords = np.linspace(-radius_xy[0], radius_xy[0], samples_xy[0])
    y_coords = np.linspace(-radius_xy[1], radius_xy[1], samples_xy[1])

    # Sample one row of elevations for each y coordinate.
    rows = []
    for y_value in y_coords:
        rows.append([heightmap(x_value, y_value) for x_value in x_coords])

    return np.array(rows).flatten()
================================================
FILE: src/jaxsim/mujoco/utils.py
================================================
from __future__ import annotations
import dataclasses
from collections.abc import Sequence
import mujoco as mj
import numpy as np
import numpy.typing as npt
from scipy.spatial.transform import Rotation
from .model import MujocoModelHelper
def mujoco_data_from_jaxsim(
    mujoco_model: mj.MjModel,
    jaxsim_model,
    jaxsim_data,
    mujoco_data: mj.MjData | None = None,
    update_removed_joints: bool = True,
) -> mj.MjData:
    """
    Create a Mujoco data object from a JaxSim model and data objects.

    Args:
        mujoco_model: The Mujoco model object corresponding to the JaxSim model.
        jaxsim_model: The JaxSim model object from which the Mujoco model was created.
        jaxsim_data: The JaxSim data object containing the state of the model.
        mujoco_data: An optional Mujoco data object. If None, a new one will be created.
        update_removed_joints:
            If True, the positions of the joints that have been removed during the
            model reduction process will be set to their initial values.

    Returns:
        The Mujoco data object containing the state of the JaxSim model.

    Raises:
        ValueError:
            If `jaxsim_model` is not a JaxSimModel or `jaxsim_data` is not a
            JaxSimModelData.

    Note:
        This method is useful to initialize a Mujoco data object used for visualization
        with the state of a JaxSim model. In particular, this function takes care of
        initializing the positions of the joints that have been removed during the
        model reduction process. After the initial creation of the Mujoco data object,
        it's faster to update the state using an external MujocoModelHelper object.
    """

    # The package `jaxsim.mujoco` is supposed to be jax-independent.
    # We import all the JaxSim resources privately.
    import jaxsim.api as js

    if not isinstance(jaxsim_model, js.model.JaxSimModel):
        raise ValueError("The `jaxsim_model` argument must be a JaxSimModel object.")

    if not isinstance(jaxsim_data, js.data.JaxSimModelData):
        raise ValueError("The `jaxsim_data` argument must be a JaxSimModelData object.")

    # Create the helper to operate on the Mujoco model and data.
    model_helper = MujocoModelHelper(model=mujoco_model, data=mujoco_data)

    # If the model is fixed-base, the Mujoco model won't have the joint corresponding
    # to the floating base, and the helper would raise an exception.
    if jaxsim_model.floating_base():

        # Set the model position.
        model_helper.set_base_position(position=np.array(jaxsim_data.base_position))

        # Set the model orientation.
        model_helper.set_base_orientation(
            orientation=np.array(jaxsim_data.base_orientation)
        )

    # Set the joint positions.
    if jaxsim_model.dofs() > 0:
        model_helper.set_joint_positions(
            joint_names=list(jaxsim_model.joint_names()),
            positions=np.array(jaxsim_data.joint_positions),
        )

    # Updating these joints is not necessary after the first time.
    # Users can disable this update after initialization.
    if update_removed_joints:

        # Both name collections do not change while iterating, therefore
        # they are computed only once and stored as sets for fast lookups.
        jaxsim_joint_names = set(jaxsim_model.joint_names())
        mujoco_joint_names = set(model_helper.joint_names())

        # Create a dictionary with the joints that have been removed for various reasons
        # (like link lumping due to model reduction).
        joints_removed_dict = {
            j.name: j
            for j in jaxsim_model.description._joints_removed
            if j.name not in jaxsim_joint_names
        }

        # Set the positions of the removed joints, selecting only the original joints
        # removed from the JaxSim model that are still present in the Mujoco model.
        for joint_name, joint in joints_removed_dict.items():

            if joint_name not in mujoco_joint_names:
                continue

            model_helper.set_joint_position(
                position=joint.initial_position,
                joint_name=joint_name,
            )

    # Return the mujoco data with updated kinematics.
    mj.mj_forward(mujoco_model, model_helper.data)

    return model_helper.data
@dataclasses.dataclass
class MujocoCamera:
"""
Helper class storing parameters of a Mujoco camera.
Refer to the official documentation for more details:
https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-camera
"""
mode: str = "fixed"
target: str | None = None
fovy: str = "45"
pos: str = "0 0 0"
quat: str | None = None
axisangle: str | None = None
xyaxes: str | None = None
zaxis: str | None = None
euler: str | None = None
name: str | None = None
@classmethod
def build(cls, **kwargs) -> MujocoCamera:
"""
Build a Mujoco camera from a dictionary.
"""
if not all(isinstance(value, str) for value in kwargs.values()):
raise ValueError(f"Values must be strings: {kwargs}")
return cls(**kwargs)
@staticmethod
def build_from_target_view(
camera_name: str,
mode: str = "fixed",
lookat: Sequence[float | int] | npt.NDArray = (0, 0, 0),
distance: float | int | npt.NDArray = 3,
azimuth: float | int | npt.NDArray = 90,
elevation: float | int | npt.NDArray = -45,
fovy: float | int | npt.NDArray = 45,
degrees: bool = True,
**kwargs,
) -> MujocoCamera:
"""
Create a custom camera that looks at a target point.
Note:
The choice of the parameters is easier if we imagine to consider a target
frame `T` whose origin is located over the lookat point and having the same
orientation of the world frame `W`. We also introduce a camera frame `C`
whose origin is located over the lower-left corner of the image, and having
the x-axis pointing right and the y-axis pointing up in image coordinates.
The camera renders what it sees in the -z direction of frame `C`.
Args:
camera_name: The name of the camera.
mode: Camera positioning mode:
- **"fixed"**: Fixed position and orientation relative to the body.
- **"track"**: Fixed offset from the body in world coordinates, constant orientation.
- **"trackcom"**: Like `"track"`, but relative to the center of mass of the subtree.
- **"targetbody"**: Fixed position in body frame, oriented toward a target body.
- **"targetbodycom"**: Like `"targetbody"`, but targets the subtree's center of mass.
lookat: The target point to look at (origin of `T`).
distance:
The distance from the target point (displacement between the origins
of `T` and `C`).
azimuth:
The rotation around z of the camera. With an angle of 0, the camera
would loot at the target point towards the positive x-axis of `T`.
elevation:
The rotation around the x-axis of the camera frame `C`. Note that if
you want to lift the view angle, the elevation is negative.
fovy: The field of view of the camera.
degrees: Whether the angles are in degrees or radians.
**kwargs: Additional camera parameters.
Returns:
The custom camera.
"""
# Start from a frame whose origin is located over the lookat point.
# We initialize a -90 degrees rotation around the z-axis because due to
# the default camera coordinate system (x pointing right, y pointing up).
W_H_C = np.eye(4)
W_H_C[0:3, 3] = np.array(lookat)
W_H_C[0:3, 0:3] = Rotation.from_euler(
seq="ZX", angles=[-90, 90], degrees=True
).as_matrix()
# Process the azimuth.
R_az = Rotation.from_euler(seq="Y", angles=azimuth, degrees=degrees).as_matrix()
W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_az
# Process elevation.
R_el = Rotation.from_euler(
seq="X", angles=elevation, degrees=degrees
).as_matrix()
W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_el
# Process distance.
tf_distance = np.eye(4)
tf_distance[2, 3] = distance
W_H_C = W_H_C @ tf_distance
# Extract the position and the quaternion.
p = W_H_C[0:3, 3]
Q = Rotation.from_matrix(W_H_C[0:3, 0:3]).as_quat(scalar_first=True)
return MujocoCamera.build(
name=camera_name,
mode=mode,
fovy=str(fovy if degrees else np.rad2deg(fovy)),
pos=" ".join(p.astype(str).tolist()),
quat=" ".join(Q.astype(str).tolist()),
**kwargs,
)
def asdict(self) -> dict[str, str]:
"""
Convert the camera to a dictionary.
"""
return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}
================================================
FILE: src/jaxsim/mujoco/visualizer.py
================================================
import contextlib
import pathlib
from collections.abc import Iterator, Sequence
import mediapy as media
import mujoco as mj
import mujoco.viewer
import numpy as np
import numpy.typing as npt
from scipy.spatial.transform import Rotation
class MujocoVideoRecorder:
"""
Video recorder for the MuJoCo passive viewer.
"""
def __init__(
self,
model: list[mj.MjModel] | mj.MjModel,
data: list[mj.MjData] | mj.MjData,
fps: int = 30,
width: int | None = None,
height: int | None = None,
**kwargs,
) -> None:
"""
Initialize the Mujoco video recorder.
Args:
model: The Mujoco model.
data: The Mujoco data.
fps: The frames per second.
width: The width of the video.
height: The height of the video.
**kwargs: Additional arguments for the renderer.
"""
if isinstance(model, mj.MjModel):
single_model = model
elif isinstance(model, list) and len(model) == 1:
single_model = model[0]
else:
raise ValueError(
"Model must be a single instance of mj.MjModel or a list with at least one element."
)
width = width if width is not None else single_model.vis.global_.offwidth
height = height if height is not None else single_model.vis.global_.offheight
if single_model.vis.global_.offwidth != width:
single_model.vis.global_.offwidth = width
if single_model.vis.global_.offheight != height:
single_model.vis.global_.offheight = height
self.fps = fps
self.frames: list[npt.NDArray] = []
self.data: list[mj.MjData] | mj.MjData | None = None
self.model: list[mj.MjModel] | mj.MjModel | None = None
self.reset(model=model, data=data)
self.renderer = mujoco.Renderer(
model=single_model,
**(dict(width=width, height=height) | kwargs),
)
def visualize_frame(
self, frame_pose: list[float] | npt.NDArray | None = None
) -> None:
"""
Add visualization of a static frame.
Args:
frame_pose: The pose of a static frame to visualize as [x, y, z, roll, pitch, yaw].
"""
scene = self.renderer.scene
# Three free slots are needed for the axes (x, y, z).
if scene.ngeom + 3 > scene.maxgeom:
return
# Read position and RPY orientation
if not frame_pose:
return
try:
x, y, z, roll, pitch, yaw = frame_pose
except Exception as e:
raise ValueError(
"Frame pose elements must be a 6D list: 'x y z roll pitch yaw'"
) from e
mat = Rotation.from_euler("xyz", [roll, pitch, yaw], degrees=False).as_matrix()
origin = np.array([x, y, z])
length = 0.2 # length of axis cylinders
radius = 0.01 # slim radius for cylinders
for axis, color in zip(
range(3), [(1, 0, 0, 1), (0, 1, 0, 1), (0, 0, 1, 1)], strict=True
):
if scene.ngeom >= scene.maxgeom:
break
axis_dir = mat[:, axis]
geom = scene.geoms[scene.ngeom]
# Cylinder position is centered at origin + half length along axis
pos = origin + axis_dir * length * 0.5
# Build rotation matrix for cylinder aligned with axis_dir
# MuJoCo's cylinder local axis is along z-axis
def rot_from_z(v: np.ndarray) -> np.ndarray:
v = v / np.linalg.norm(v)
z_axis = np.array([0, 0, 1])
if np.allclose(v, z_axis):
return np.eye(3)
if np.allclose(v, -z_axis):
return np.diag([1, -1, -1])
cross = np.cross(z_axis, v)
dot = np.dot(z_axis, v)
skew = np.array(
[
[0, -cross[2], cross[1]],
[cross[2], 0, -cross[0]],
[-cross[1], cross[0], 0],
]
)
R = np.eye(3) + skew + skew @ skew * (1 / (1 + dot))
return R
R = rot_from_z(axis_dir)
mat_flat = R.flatten()
mj.mjv_initGeom(
geom=geom,
type=mj.mjtGeom.mjGEOM_CYLINDER,
# The `size` arguments takes three positional arguments.
# In the cylinder case, the first two are the radius and half-length,
# and the third is not used (set to 0.0).
size=np.array([radius, length * 0.5, 0.0]),
rgba=np.array(color),
pos=pos,
mat=mat_flat,
)
geom.category = mj.mjtCatBit.mjCAT_STATIC
scene.ngeom += 1
def reset(
self,
model: mj.MjModel | None = None,
data: list[mj.MjData] | mj.MjData | None = None,
) -> None:
"""Reset the model and data."""
self.frames = []
self.data = data if data is not None else self.data
self.data = self.data if isinstance(self.data, list) else [self.data]
self.model = model if model is not None else self.model
self.model = self.model if isinstance(self.model, list) else [self.model]
assert len(self.data) == len(self.model) or len(self.model) == 1, (
f"Length mismatch: len(data)={len(self.data)}, len(model)={len(self.model)}. "
"They must be equal or model must have length 1."
)
def render_frame(
self,
camera_name: str = "track",
frame_pose: list[float] | npt.NDArray | None = None,
) -> npt.NDArray:
"""
Render a frame.
Args:
camera_name: The name of the camera to use for rendering.
frame_pose: The pose of a static frame to visualize as [x, y, z, roll, pitch, yaw].
Returns:
The rendered frame as a NumPy array.
"""
for idx, data in enumerate(self.data):
# Use a single model for rendering if multiple data instances are provided.
# Otherwise, use the data index to select the corresponding model.
model = self.model[0] if len(self.model) == 1 else self.model[idx]
mj.mj_forward(model, data)
if idx == 0:
self.renderer.update_scene(data=data, camera=camera_name)
self.visualize_frame(frame_pose=frame_pose)
continue
mujoco.mjv_addGeoms(
m=model,
d=data,
opt=mj.MjvOption(),
pert=mj.MjvPerturb(),
catmask=mj.mjtCatBit.mjCAT_DYNAMIC,
scn=self.renderer.scene,
)
return self.renderer.render()
def record_frame(
self,
camera_name: str = "track",
frame_pose: list[float] | npt.NDArray | None = None,
) -> None:
"""Store a frame in the buffer."""
frame = self.render_frame(camera_name=camera_name, frame_pose=frame_pose)
self.frames.append(frame)
def write_video(self, path: pathlib.Path | str, exist_ok: bool = False) -> None:
"""Write the video to a file."""
# Resolve the path to the video.
path = pathlib.Path(path).expanduser().resolve()
if path.is_dir():
raise IsADirectoryError(f"The path '{path}' is a directory.")
if not exist_ok and path.is_file():
raise FileExistsError(f"The file '{path}' already exists.")
media.write_video(path=path, images=np.array(self.frames), fps=self.fps)
@staticmethod
def compute_down_sampling(original_fps: int, target_min_fps: int) -> int:
"""
Return the integer down-sampling factor to reach at least the target fps.
Args:
original_fps: The original fps.
target_min_fps: The target minimum fps.
Returns:
The down-sampling factor.
"""
down_sampling = 1
down_sampling_final = down_sampling
while original_fps / (down_sampling + 1) >= target_min_fps:
down_sampling = down_sampling + 1
if int(original_fps / down_sampling) == original_fps / down_sampling:
down_sampling_final = down_sampling
return down_sampling_final
class MujocoVisualizer:
    """
    Visualizer for the MuJoCo passive viewer.
    """

    def __init__(
        self, model: mj.MjModel | None = None, data: mj.MjData | None = None
    ) -> None:
        """
        Initialize the Mujoco visualizer.

        Args:
            model: The Mujoco model.
            data: The Mujoco data.
        """

        self.data = data
        self.model = model

    def sync(
        self,
        viewer: mj.viewer.Handle,
        model: mj.MjModel | None = None,
        data: mj.MjData | None = None,
    ) -> None:
        """
        Update the viewer with the current model and data.

        Args:
            viewer: The handle of the viewer to synchronize.
            model: The Mujoco model. If None, the stored model is used.
            data: The Mujoco data. If None, the stored data is used.
        """

        data = data if data is not None else self.data
        model = model if model is not None else self.model

        mj.mj_forward(model, data)
        viewer.sync()

    def open_viewer(
        self,
        model: mj.MjModel | None = None,
        data: mj.MjData | None = None,
        show_left_ui: bool = False,
    ) -> mj.viewer.Handle:
        """
        Open a passive viewer.

        Args:
            model: The Mujoco model. If None, the stored model is used.
            data: The Mujoco data. If None, the stored data is used.
            show_left_ui: Whether to show the left panel of the viewer.

        Returns:
            The handle of the opened viewer.
        """

        data = data if data is not None else self.data
        model = model if model is not None else self.model

        handle = mj.viewer.launch_passive(
            model, data, show_left_ui=show_left_ui, show_right_ui=False
        )

        return handle

    @contextlib.contextmanager
    def open(
        self,
        model: mj.MjModel | None = None,
        data: mj.MjData | None = None,
        *,
        show_left_ui: bool = False,
        close_on_exit: bool = True,
        lookat: Sequence[float | int] | npt.NDArray | None = None,
        distance: float | int | npt.NDArray | None = None,
        azimuth: float | int | npt.NDArray | None = None,
        elevation: float | int | npt.NDArray | None = None,
    ) -> Iterator[mj.viewer.Handle]:
        """
        Context manager to open the Mujoco passive viewer.

        Args:
            model: The Mujoco model. If None, the stored model is used.
            data: The Mujoco data. If None, the stored data is used.
            show_left_ui: Whether to show the left panel of the viewer.
            close_on_exit: Whether to close the viewer when leaving the context.
            lookat: The optional point the camera looks at.
            distance: The optional distance of the camera.
            azimuth: The optional azimuth of the camera.
            elevation: The optional elevation of the camera.

        Yields:
            The handle of the opened viewer.

        Note:
            Refer to the Mujoco documentation for details of the camera options:
            https://mujoco.readthedocs.io/en/stable/XMLreference.html#visual-global
        """

        handle = self.open_viewer(model=model, data=data, show_left_ui=show_left_ui)

        # The viewer is already open at this point. If the camera configuration
        # fails, the handle never reaches the caller, therefore it is closed
        # here before propagating the exception.
        try:
            handle = MujocoVisualizer.setup_viewer_camera(
                viewer=handle,
                lookat=lookat,
                distance=distance,
                azimuth=azimuth,
                elevation=elevation,
            )
        except Exception:
            handle.close()
            raise

        try:
            yield handle
        finally:
            if close_on_exit:
                handle.close()

    @staticmethod
    def setup_viewer_camera(
        viewer: mj.viewer.Handle,
        *,
        lookat: Sequence[float | int] | npt.NDArray | None = None,
        distance: float | int | npt.NDArray | None = None,
        azimuth: float | int | npt.NDArray | None = None,
        elevation: float | int | npt.NDArray | None = None,
    ) -> mj.viewer.Handle:
        """
        Configure the initial viewpoint of the Mujoco passive viewer.

        Only the options different from None are applied to the camera.

        Args:
            viewer: The handle of the viewer to configure.
            lookat: The point the camera looks at (3 elements).
            distance: The distance of the camera.
            azimuth: The azimuth of the camera, wrapped to [0, 360).
            elevation: The elevation of the camera.

        Note:
            Refer to the Mujoco documentation for details of the camera options:
            https://mujoco.readthedocs.io/en/stable/XMLreference.html#visual-global

        Returns:
            The viewer with configured camera.

        Raises:
            ValueError: If `lookat` does not have exactly 3 elements.
        """

        if lookat is not None:

            lookat_array = np.array(lookat, dtype=float).squeeze()

            if lookat_array.size != 3:
                raise ValueError(lookat)

            viewer.cam.lookat = lookat_array

        if distance is not None:
            viewer.cam.distance = float(distance)

        if azimuth is not None:
            viewer.cam.azimuth = float(azimuth) % 360

        if elevation is not None:
            viewer.cam.elevation = float(elevation)

        return viewer
================================================
FILE: src/jaxsim/parsers/__init__.py
================================================
================================================
FILE: src/jaxsim/parsers/descriptions/__init__.py
================================================
from .collision import (
BoxCollision,
CollidablePoint,
CollisionShape,
MeshCollision,
SphereCollision,
)
from .joint import JointDescription, JointGenericAxis, JointType
from .link import LinkDescription
from .model import ModelDescription
================================================
FILE: src/jaxsim/parsers/descriptions/collision.py
================================================
from __future__ import annotations
import abc
import dataclasses
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt
import jaxsim.typing as jtp
from jaxsim import logging
from .link import LinkDescription
@dataclasses.dataclass
class CollidablePoint:
    """
    A point that can collide, rigidly attached to a parent link.

    Attributes:
        parent_link: The link the collidable point is attached to.
        position: The position of the point relative to the parent link.
        enabled: Whether the point is enabled for collision detection.
    """

    parent_link: LinkDescription
    position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3))
    enabled: bool = True

    def change_link(
        self, new_link: LinkDescription, new_H_old: npt.NDArray
    ) -> CollidablePoint:
        """
        Attach the collidable point to a different parent link.

        Args:
            new_link: The new parent link of the collidable point.
            new_H_old:
                The homogeneous transform applied to the current position
                to express it in the frame of the new link.

        Returns:
            A new collidable point associated with the new parent link.
        """

        logging.debug(
            msg=f"Moving collidable point: {self.parent_link.name} -> {new_link.name}"
        )

        # Transform the position in homogeneous coordinates, then drop the
        # trailing element.
        homogeneous_position = jnp.hstack([self.position, 1.0])
        moved_position = (new_H_old @ homogeneous_position).squeeze()[0:3]

        return CollidablePoint(
            parent_link=new_link,
            position=moved_position,
            enabled=self.enabled,
        )

    def __hash__(self) -> int:

        link_hash = hash(self.parent_link)
        position_hash = hash(tuple(self.position.tolist()))
        enabled_hash = hash(self.enabled)

        return hash((link_hash, position_hash, enabled_hash))

    def __eq__(self, other: CollidablePoint) -> bool:

        # Two collidable points are equal when their hashes match.
        if isinstance(other, CollidablePoint):
            return hash(self) == hash(other)

        return False

    def __str__(self) -> str:

        return (
            f"{self.__class__.__name__}(parent_link={self.parent_link.name}"
            f", position={self.position}, enabled={self.enabled})"
        )
@dataclasses.dataclass
class CollisionShape(abc.ABC):
    """
    Abstract base class for representing collision shapes.

    Attributes:
        collidable_points: A list of collidable points associated with the collision shape.
    """

    collidable_points: tuple[CollidablePoint]

    def __str__(self):
        return (
            f"{self.__class__.__name__}("
            + "collidable_points=[\n "
            + ",\n ".join(str(cp) for cp in self.collidable_points)
            + "\n])"
        )


@dataclasses.dataclass
class BoxCollision(CollisionShape):
    """
    Represents a box-shaped collision shape.

    Attributes:
        center: The center of the box in the local frame of the collision shape.
    """

    center: jtp.VectorLike

    def __hash__(self) -> int:
        # The hash is computed from the content of the shape. Hashing `super()`
        # is not an option since the proxy object is hashed by identity and
        # ignores the collidable points.
        return hash(
            (
                hash(tuple(self.collidable_points)),
                hash(tuple(np.atleast_1d(self.center).tolist())),
            )
        )

    def __eq__(self, other: BoxCollision) -> bool:

        if not isinstance(other, BoxCollision):
            return False

        return hash(self) == hash(other)


@dataclasses.dataclass
class SphereCollision(CollisionShape):
    """
    Represents a spherical collision shape.

    Attributes:
        center: The center of the sphere in the local frame of the collision shape.
    """

    center: jtp.VectorLike

    def __hash__(self) -> int:
        # The hash is computed from the content of the shape. Hashing `super()`
        # is not an option since the proxy object is hashed by identity and
        # ignores the collidable points.
        return hash(
            (
                hash(tuple(self.collidable_points)),
                hash(tuple(np.atleast_1d(self.center).tolist())),
            )
        )

    def __eq__(self, other: SphereCollision) -> bool:

        # A sphere can only be equal to another sphere.
        if not isinstance(other, SphereCollision):
            return False

        return hash(self) == hash(other)


@dataclasses.dataclass
class MeshCollision(CollisionShape):
    """
    Represents a mesh-shaped collision shape.

    Attributes:
        center: The center of the mesh in the local frame of the collision shape.
    """

    center: jtp.VectorLike

    def __hash__(self) -> int:
        # The hash is computed from the content of the shape.
        return hash(
            (
                hash(tuple(np.atleast_1d(self.center).tolist())),
                hash(tuple(self.collidable_points)),
            )
        )

    def __eq__(self, other: MeshCollision) -> bool:

        if not isinstance(other, MeshCollision):
            return False

        return hash(self) == hash(other)
================================================
FILE: src/jaxsim/parsers/descriptions/joint.py
================================================
from __future__ import annotations
import dataclasses
from typing import ClassVar
import jax_dataclasses
import numpy as np
import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass, Mutability
from .link import LinkDescription
@dataclasses.dataclass(frozen=True)
class JointType:
    """
    Integer codes identifying the supported types of joints.

    The codes are class-level constants, therefore they are accessed directly
    from the class (e.g. `JointType.Revolute`).
    """

    # Code of fixed joints.
    Fixed: ClassVar[int] = 0

    # Code of revolute joints.
    Revolute: ClassVar[int] = 1

    # Code of prismatic joints.
    Prismatic: ClassVar[int] = 2
@jax_dataclasses.pytree_dataclass
class JointGenericAxis:
    """
    Joint whose motion is defined by a generic 3D axis.
    """

    # The axis of rotation or translation of the joint (must have norm 1).
    axis: jtp.Vector

    def __hash__(self) -> int:

        # The hash is computed from the elements of the axis.
        axis_elements = tuple(self.axis.tolist())
        return hash(axis_elements)

    def __eq__(self, other: JointGenericAxis) -> bool:

        # Two objects are equal when their hashes match.
        if isinstance(other, JointGenericAxis):
            return hash(self) == hash(other)

        return False
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class JointDescription(JaxsimDataclass):
    """
    In-memory description of a robot joint.

    Attributes:
        name: The name of the joint.
        axis: The axis of rotation or translation for the joint.
        pose: The pose transformation matrix of the joint.
        jtype: The type of the joint.
        child: The child link attached to the joint.
        parent: The parent link attached to the joint.
        index: An optional index for the joint.
        friction_static: The static friction coefficient for the joint.
        friction_viscous: The viscous friction coefficient for the joint.
        position_limit_damper: The damper coefficient for position limits.
        position_limit_spring: The spring coefficient for position limits.
        position_limit: The position limits for the joint.
        initial_position: The initial position of the joint.
        motor_inertia: The inertia of the motor of the joint.
        motor_viscous_friction: The viscous friction of the motor of the joint.
        motor_gear_ratio: The gear ratio of the motor of the joint.
    """

    name: jax_dataclasses.Static[str]
    axis: jtp.Vector
    pose: jtp.Matrix
    jtype: jax_dataclasses.Static[jtp.IntLike]

    # The links are excluded from the repr with `dataclasses.field`. They remain
    # optional for backward compatibility, with None as the default value.
    child: LinkDescription = dataclasses.field(default=None, repr=False)
    parent: LinkDescription = dataclasses.field(default=None, repr=False)

    index: jtp.IntLike | None = None

    friction_static: jtp.FloatLike = 0.0
    friction_viscous: jtp.FloatLike = 0.0

    position_limit_damper: jtp.FloatLike = 0.0
    position_limit_spring: jtp.FloatLike = 0.0

    position_limit: tuple[jtp.FloatLike, jtp.FloatLike] = (0.0, 0.0)
    initial_position: jtp.FloatLike | jtp.VectorLike = 0.0

    motor_inertia: jtp.FloatLike = 0.0
    motor_viscous_friction: jtp.FloatLike = 0.0
    motor_gear_ratio: jtp.FloatLike = 1.0

    def __post_init__(self) -> None:

        # Normalize the axis, if any, so that it always has unit norm.
        if self.axis is not None:

            with self.mutable_context(
                mutability=Mutability.MUTABLE, restore_after_exception=False
            ):
                norm_of_axis = np.linalg.norm(self.axis)
                self.axis = self.axis / norm_of_axis

    def __eq__(self, other: JointDescription) -> bool:

        if not isinstance(other, JointDescription):
            return False

        # Two joint descriptions are equal when their hashes match.
        return hash(self) == hash(other)

    def __hash__(self) -> int:

        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                hash(self.name),
                HashedNumpyArray.hash_of_array(self.axis),
                HashedNumpyArray.hash_of_array(self.pose),
                hash(int(self.jtype)),
                hash(self.child),
                hash(self.parent),
                hash(int(self.index)) if self.index is not None else 0,
                HashedNumpyArray.hash_of_array(self.friction_static),
                HashedNumpyArray.hash_of_array(self.friction_viscous),
                HashedNumpyArray.hash_of_array(self.position_limit_damper),
                HashedNumpyArray.hash_of_array(self.position_limit_spring),
                HashedNumpyArray.hash_of_array(self.position_limit),
                HashedNumpyArray.hash_of_array(self.initial_position),
                HashedNumpyArray.hash_of_array(self.motor_inertia),
                HashedNumpyArray.hash_of_array(self.motor_viscous_friction),
                HashedNumpyArray.hash_of_array(self.motor_gear_ratio),
            ),
        )
================================================
FILE: src/jaxsim/parsers/descriptions/link.py
================================================
from __future__ import annotations
import dataclasses
import jax.numpy as jnp
import jax_dataclasses
import numpy as np
from jax_dataclasses import Static
import jaxsim.typing as jtp
from jaxsim.math import Adjoint
from jaxsim.utils import JaxsimDataclass
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class LinkDescription(JaxsimDataclass):
    """
    In-memory description of a robot link.

    Attributes:
        name: The name of the link.
        mass: The mass of the link.
        inertia: The inertia tensor of the link.
        index: An optional index for the link (it gets automatically assigned).
        parent_name: The name of the parent link of this link.
        pose: The pose transformation matrix of the link.
        children: The children links.
    """

    name: Static[str]
    mass: float = dataclasses.field(repr=False)
    inertia: jtp.Matrix = dataclasses.field(repr=False)
    index: int | None = None
    parent_name: Static[str | None] = dataclasses.field(default=None, repr=False)
    pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False)

    # The default is an empty tuple, matching the annotation and keeping
    # the static field hashable.
    children: Static[tuple[LinkDescription]] = dataclasses.field(
        default_factory=tuple, repr=False
    )

    def __hash__(self) -> int:

        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                hash(self.name),
                hash(float(self.mass)),
                HashedNumpyArray.hash_of_array(self.inertia),
                hash(int(self.index)) if self.index is not None else 0,
                HashedNumpyArray.hash_of_array(self.pose),
                hash(tuple(self.children)),
                hash(self.parent_name) if self.parent_name is not None else 0,
            )
        )

    def __eq__(self, other: LinkDescription) -> bool:

        if not isinstance(other, LinkDescription):
            return False

        # The children are compared as tuples, consistently with `__hash__`,
        # so that the type of the container does not affect the result.
        return bool(
            self.name == other.name
            and np.allclose(self.mass, other.mass)
            and np.allclose(self.inertia, other.inertia)
            and self.index == other.index
            and np.allclose(self.pose, other.pose)
            and tuple(self.children) == tuple(other.children)
            and self.parent_name == other.parent_name
        )

    @property
    def name_and_index(self) -> str:
        """
        Get a formatted string with the link's name and index.

        Returns:
            str: The formatted string.
        """

        return f"#{self.index}_<{self.name}>"

    def lump_with(
        self, link: LinkDescription, lumped_H_removed: jtp.Matrix
    ) -> LinkDescription:
        """
        Combine the current link with another link.

        The mass of the combined link is the sum of the two masses. Its inertia
        is the sum of the inertia of the current link and the inertia of the
        other link transformed with `lumped_H_removed`.

        Args:
            link: The link to combine with.
            lumped_H_removed: The transformation matrix between the two links.

        Returns:
            The combined link.
        """

        # Get the 6D inertia of the link to remove.
        I_removed = link.inertia

        # Create the adjoint of the transform. Note the inverse.
        r_X_l = Adjoint.from_transform(transform=lumped_H_removed, inverse=True)

        # Move the inertia.
        I_removed_in_lumped_frame = r_X_l.transpose() @ I_removed @ r_X_l

        # Create the new combined link.
        lumped_link = self.replace(
            mass=self.mass + link.mass,
            inertia=self.inertia + I_removed_in_lumped_frame,
        )

        return lumped_link
================================================
FILE: src/jaxsim/parsers/descriptions/model.py
================================================
from __future__ import annotations
import dataclasses
import itertools
from collections.abc import Sequence
from jaxsim import logging
from jaxsim.logging import jaxsim_warn
from ..kinematic_graph import KinematicGraph, KinematicGraphTransforms, RootPose
from .collision import CollidablePoint, CollisionShape
from .joint import JointDescription
from .link import LinkDescription
@dataclasses.dataclass(frozen=True, eq=False, unsafe_hash=False)
class ModelDescription(KinematicGraph):
"""
Intermediate representation representing the kinematic graph of a robot model.
Attributes:
name: The name of the model.
fixed_base: Whether the model is either fixed-base or floating-base.
collision_shapes: List of collision shapes associated with the model.
"""
name: str = None
fixed_base: bool = True
collision_shapes: tuple[CollisionShape, ...] = dataclasses.field(
default_factory=list, repr=False
)
@staticmethod
def build_model_from(
    name: str,
    links: list[LinkDescription],
    joints: list[JointDescription],
    frames: list[LinkDescription] | None = None,
    collisions: tuple[CollisionShape, ...] = (),
    fixed_base: bool = False,
    base_link_name: str | None = None,
    considered_joints: Sequence[str] | None = None,
    model_pose: RootPose = RootPose(),
) -> ModelDescription:
    """
    Build a model description from its links, joints, frames, and collisions.

    Args:
        name: The name of the model.
        links: List of link descriptions.
        joints: List of joint descriptions.
        frames: List of frame descriptions.
        collisions: The collision shapes associated with the model.
        fixed_base: Indicates whether the model has a fixed base.
        base_link_name:
            Name of the base link (i.e. the root of the kinematic tree).
            If None, the root is the one selected by `KinematicGraph.build_from`.
        considered_joints: Joint names to consider (by default all joints).
        model_pose: Pose of the model's root (by default an identity transform).

    Returns:
        A ModelDescription instance representing the model.

    Raises:
        RuntimeError:
            If the collidable points of a collision shape do not all share the
            same parent link.
        AssertionError:
            If `base_link_name` is given and it is not the root of the
            resulting kinematic graph.
    """

    # Create the full kinematic graph.
    kinematic_graph = KinematicGraph.build_from(
        links=links,
        joints=joints,
        frames=frames,
        root_link_name=base_link_name,
        root_pose=model_pose,
    )

    # Reduce the graph if needed.
    if considered_joints is not None:
        kinematic_graph = kinematic_graph.reduce(
            considered_joints=considered_joints
        )

    # Create the object to compute forward kinematics.
    fk = KinematicGraphTransforms(graph=kinematic_graph)

    # Container of the final model's collision shapes.
    final_collisions: list[CollisionShape] = []

    # Move and express the collision shapes of removed links to the resulting
    # lumped link that replaces the combination of the removed link and its parent.
    for collision_shape in collisions:

        # All the collidable points of a shape must share a unique parent link.
        parent_names = {
            cp.parent_link.name for cp in collision_shape.collidable_points
        }

        if len(parent_names) != 1:
            msg = "Collision shape not currently supported (multiple parent links)"
            raise RuntimeError(msg)

        # Get the parent link of the collision shape.
        # Note that this link could have been lumped and we need to find the
        # link in which it was lumped into.
        parent_link_of_shape = collision_shape.collidable_points[0].parent_link

        # If it is part of the (reduced) graph, add it as it is...
        if parent_link_of_shape.name in kinematic_graph.link_names():
            final_collisions.append(collision_shape)
            continue

        # ... otherwise look for the frame.
        if parent_link_of_shape.name not in kinematic_graph.frame_names():
            msg = "Parent frame '{}' of collision shape not found, ignoring shape"
            logging.info(msg.format(parent_link_of_shape.name))
            continue

        # Find the link of the (reduced) model in which the parent of the shape
        # was lumped into. It is the same for all the points of the shape,
        # therefore it is resolved only once.
        real_parent_link_name = kinematic_graph.frames_dict[
            parent_link_of_shape.name
        ].parent_name
        real_parent_link = kinematic_graph.links_dict[real_parent_link_name]

        # Change the link associated to the collidable points, updating their
        # relative pose.
        moved_points = tuple(
            cp.change_link(
                new_link=real_parent_link,
                new_H_old=fk.relative_transform(
                    relative_to=real_parent_link_name,
                    name=cp.parent_link.name,
                ),
            )
            for cp in collision_shape.collidable_points
        )

        # Store the new collision shape holding the updated points.
        final_collisions.append(CollisionShape(collidable_points=moved_points))

    # Build the model.
    model = ModelDescription(
        name=name,
        root_pose=kinematic_graph.root_pose,
        fixed_base=fixed_base,
        collision_shapes=tuple(final_collisions),
        root=kinematic_graph.root,
        joints=kinematic_graph.joints,
        frames=kinematic_graph.frames,
        _joints_removed=kinematic_graph.joints_removed,
    )

    # Check that the root link of the kinematic graph is the desired base link.
    # When no base link was requested there is nothing to compare against.
    if base_link_name is not None:
        assert (
            kinematic_graph.root.name == base_link_name
        ), kinematic_graph.root.name

    return model
def reduce(self, considered_joints: Sequence[str]) -> ModelDescription:
    """
    Create a reduced model that only keeps the given joints.

    Args:
        considered_joints: The names of the joints to keep.

    Returns:
        A new `ModelDescription` built from the current one, whose list of
        removed joints also includes the joints removed from the current model.

    Raises:
        ValueError: If any of the given joint names is not part of the model.
    """

    jaxsim_warn(
        "The joint order in the model description is not preserved when reducing "
        "the model. Consider using the `names_to_indices` method to get the correct "
        "order of the joints, or use the `joint_names()` method to inspect the internal joint ordering."
    )

    # Joint names that were requested but are unknown to this model.
    extra_joints = set(considered_joints) - set(self.joint_names())

    if extra_joints:
        raise ValueError(f"Found joints not part of the model: {extra_joints}")

    # The first link yielded when iterating on the model is used as base link.
    base_link = next(iter(self))

    reduced = ModelDescription.build_model_from(
        name=self.name,
        links=list(self.links_dict.values()),
        joints=self.joints,
        frames=self.frames,
        collisions=self.collision_shapes,
        fixed_base=self.fixed_base,
        base_link_name=base_link.name,
        model_pose=self.root_pose,
        considered_joints=considered_joints,
    )

    # Carry over the unconnected/removed joints of the original model.
    for removed_joint in self.joints_removed:
        reduced.joints_removed.append(removed_joint)

    return reduced
def update_collision_shape_of_link(self, link_name: str, enabled: bool) -> None:
    """
    Set the `enabled` flag of all the collidable points attached to a link.

    Args:
        link_name: The name of the link.
        enabled: The value assigned to the `enabled` attribute of each point.

    Raises:
        ValueError: If the link is not part of the model.
    """

    known_links = self.link_names()

    if not (link_name in known_links):
        raise ValueError(link_name)

    shape_of_link = self.collision_shape_of_link(link_name=link_name)

    for collidable_point in shape_of_link.collidable_points:
        collidable_point.enabled = enabled
def collision_shape_of_link(self, link_name: str) -> CollisionShape:
    """
    Collect the collidable points attached to a link in a single shape.

    Args:
        link_name: The name of the link.

    Returns:
        A new `CollisionShape` holding, in order, the collidable points of all
        the model's collision shapes whose parent link has the given name.

    Raises:
        ValueError: If the link is not part of the model.
    """

    if link_name not in self.link_names():
        raise ValueError(link_name)

    points_of_link = []

    for shape in self.collision_shapes:
        for point in shape.collidable_points:
            if point.parent_link.name == link_name:
                points_of_link.append(point)

    return CollisionShape(collidable_points=points_of_link)
def all_enabled_collidable_points(self) -> list[CollidablePoint]:
    """
    Collect the enabled collidable points of all the model's collision shapes.

    Returns:
        The enabled collidable points, ordered by shape and then by their
        position inside each shape.
    """

    return [
        point
        for shape in self.collision_shapes
        for point in shape.collidable_points
        if point.enabled
    ]
def __eq__(self, other: ModelDescription) -> bool:
    """
    Compare two model descriptions.

    Two models are equal when their name, base type, root link, joints, frames,
    and root pose are all equal. Collision shapes are not compared.
    """

    if not isinstance(other, ModelDescription):
        return False

    # The attributes are compared in this order, stopping at the first mismatch.
    compared_attributes = (
        "name",
        "fixed_base",
        "root",
        "joints",
        "frames",
        "root_pose",
    )

    for attribute in compared_attributes:
        if not (getattr(self, attribute) == getattr(other, attribute)):
            return False

    return True
def __hash__(self) -> int:
    """
    Hash the model from the same attributes compared by `__eq__`.

    The lists of joints and frames are converted to tuples before being hashed.
    """

    hashed_components = (
        self.name,
        self.fixed_base,
        self.root,
        tuple(self.joints),
        tuple(self.frames),
        self.root_pose,
    )

    return hash(tuple(hash(component) for component in hashed_components))
================================================
FILE: src/jaxsim/parsers/kinematic_graph.py
================================================
from __future__ import annotations
import copy
import dataclasses
import functools
from collections.abc import Callable, Iterable, Iterator, Sequence
from typing import Any
import numpy as np
import numpy.typing as npt
import jaxsim.utils
from jaxsim import logging
from jaxsim.utils import Mutability
from .descriptions.joint import JointDescription, JointType
from .descriptions.link import LinkDescription
@dataclasses.dataclass
class RootPose:
    """
    Pose of the root link of a kinematic graph.

    Attributes:
        root_position: The 3D position of the root link (defaults to zeros).
        root_quaternion:
            The quaternion representing the rotation of the root link
            (defaults to ``[1, 0, 0, 0]``).

    Note:
        The root link of the kinematic graph is the base link.
    """

    root_position: npt.NDArray = dataclasses.field(
        default_factory=lambda: np.zeros(3)
    )

    root_quaternion: npt.NDArray = dataclasses.field(
        default_factory=lambda: np.array([1.0, 0, 0, 0])
    )

    def __hash__(self) -> int:
        # Imported here, as in the rest of this class, only when hashing.
        from jaxsim.utils.wrappers import HashedNumpyArray

        position_hash = HashedNumpyArray.hash_of_array(self.root_position)
        quaternion_hash = HashedNumpyArray.hash_of_array(self.root_quaternion)

        return hash((position_hash, quaternion_hash))

    def __eq__(self, other: RootPose) -> bool:
        # Objects of any other type are never equal to a root pose.
        if not isinstance(other, RootPose):
            return False

        # Both arrays are compared with the default tolerances of np.allclose.
        return bool(
            np.allclose(self.root_position, other.root_position)
            and np.allclose(self.root_quaternion, other.root_quaternion)
        )
@dataclasses.dataclass(frozen=True)
class KinematicGraph(Sequence[LinkDescription]):
"""
Class storing a kinematic graph having links as nodes and joints as edges.
The graph implements the sequence protocol: iterating on it yields its links
in breadth-first order starting from the root.
Attributes:
root: The root node of the kinematic graph.
frames: List of frames rigidly attached to the graph nodes.
joints: List of joints connecting the graph nodes.
root_pose: The pose of the kinematic graph's root.
"""
# Root link; the other links are reached by following the `children` of each link.
root: LinkDescription
# Frames of the graph. They are excluded from the generated hash and comparison.
frames: list[LinkDescription] = dataclasses.field(
default_factory=list, hash=False, compare=False
)
# Joints of the graph. They are excluded from the generated hash and comparison.
joints: list[JointDescription] = dataclasses.field(
default_factory=list, hash=False, compare=False
)
# Pose of the root link. Each instance gets its own default RootPose object.
root_pose: RootPose = dataclasses.field(default_factory=RootPose)
# Private attribute storing optional additional info.
_extra_info: dict[str, Any] = dataclasses.field(
default_factory=dict, repr=False, hash=False, compare=False
)
# Private attribute storing the unconnected joints from the parsed model and
# the joints removed after model reduction.
_joints_removed: list[JointDescription] = dataclasses.field(
default_factory=list, repr=False, hash=False, compare=False
)
@functools.cached_property
def links_dict(self) -> dict[str, LinkDescription]:
    """
    Map the name of each link of the graph to its description.

    The links are collected by iterating on the graph.
    """

    name_to_link = {}

    for link in iter(self):
        name_to_link[link.name] = link

    return name_to_link
@functools.cached_property
def frames_dict(self) -> dict[str, LinkDescription]:
    """
    Map the name of each frame of the graph to its description.
    """

    name_to_frame = {}

    for frame in self.frames:
        name_to_frame[frame.name] = frame

    return name_to_frame
@functools.cached_property
def joints_dict(self) -> dict[str, JointDescription]:
    """
    Map the name of each joint of the graph to its description.
    """

    name_to_joint = {}

    for joint in self.joints:
        name_to_joint[joint.name] = joint

    return name_to_joint
@functools.cached_property
def joints_connection_dict(
    self,
) -> dict[tuple[str, str], JointDescription]:
    """
    Map each (parent link name, child link name) pair to the connecting joint.
    """

    connection_to_joint = {}

    for joint in self.joints:
        connection = (joint.parent.name, joint.child.name)
        connection_to_joint[connection] = joint

    return connection_to_joint
def __post_init__(self) -> None:
"""
Assign the indices of links, frames, and joints, and sort frames and joints.
Links are numbered in the order in which they are yielded by iterating on the
graph, frames are sorted by name and numbered after the links, and each joint
takes the index of its child link. The `frames` and `joints` attributes are
replaced with their sorted versions.
Raises:
AssertionError: If link, frame, or joint names are not unique, if a link
shares its name with a frame or a joint, or if two joints end up having
the same index.
"""
# Assign the link index by traversing the graph with BFS.
# Here we assume the model being fixed-base, therefore the base link will
# have index 0. We will deal with the floating base in a later stage.
for index, link in enumerate(self):
link.mutable(validate=False).index = index
# Get the names of the links, frames, and joints.
link_names = [l.name for l in self]
frame_names = [f.name for f in self.frames]
joint_names = [j.name for j in self.joints]
# Make sure that they are unique.
assert len(link_names) == len(set(link_names))
assert len(frame_names) == len(set(frame_names))
assert len(joint_names) == len(set(joint_names))
# Link names must not clash with frame names nor with joint names.
assert set(link_names).isdisjoint(set(frame_names))
assert set(link_names).isdisjoint(set(joint_names))
# Order frames with their name.
# The dataclass is frozen, therefore the attribute is set through the parent class.
super().__setattr__("frames", sorted(self.frames, key=lambda f: f.name))
# Assign the frame index following the name-based indexing.
# We assume the model being fixed-base, therefore the first frame will
# have last_link_idx + 1.
for index, frame in enumerate(self.frames):
with frame.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
frame.index = index + len(self.link_names())
# Number joints so that their index matches their child link index.
# Therefore, the first joint has index 1.
links_dict = {l.name: l for l in iter(self)}
for joint in self.joints:
with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
joint.index = links_dict[joint.child.name].index
# Check that joint indices are unique.
assert len([j.index for j in self.joints]) == len(
{j.index for j in self.joints}
)
# Order joints with their indices.
super().__setattr__("joints", sorted(self.joints, key=lambda j: j.index))
@staticmethod
def build_from(
    links: list[LinkDescription],
    joints: list[JointDescription],
    frames: list[LinkDescription] | None = None,
    root_link_name: str | None = None,
    root_pose: RootPose = RootPose(),
) -> KinematicGraph:
    """
    Assemble a kinematic graph from lists of links, joints, and frames.

    Elements that are not connected to the root link are left out of the graph
    and a warning is logged for each of them. The unconnected joints are stored
    in the list of removed joints of the returned graph.

    Args:
        links: A list of link descriptions.
        joints: A list of joint descriptions.
        frames: A list of frame descriptions.
        root_link_name:
            The name of the root link. If not provided, the name of the first
            link of the list is used.
        root_pose: The root pose of the kinematic graph.

    Returns:
        The resulting kinematic graph.
    """

    # Fall back to the first link when the root link is not specified.
    if root_link_name is None:
        root_link_name = links[0].name
        logging.debug(msg=f"Assuming '{root_link_name}' as the root link")

    # Couple links and joints and create the graph of links.
    # Note that the pose of the frames is not updated; it is the caller's
    # responsibility to update their pose if they want to use them.
    graph_components = KinematicGraph._create_graph(
        links=links, joints=joints, root_link_name=root_link_name, frames=frames
    )

    (
        graph_root_node,
        graph_joints,
        graph_frames,
        unconnected_links,
        unconnected_joints,
        unconnected_frames,
    ) = graph_components

    # Notify the user about every element left out of the graph.
    unconnected_elements = (
        ("link", unconnected_links),
        ("joint", unconnected_joints),
        ("frame", unconnected_frames),
    )

    for kind, elements in unconnected_elements:
        for element in elements:
            logging.warning(msg=f"Ignoring unconnected {kind}: '{element.name}'")

    return KinematicGraph(
        root=graph_root_node,
        joints=graph_joints,
        frames=graph_frames,
        root_pose=root_pose,
        _joints_removed=unconnected_joints,
    )
@staticmethod
def _create_graph(
links: list[LinkDescription],
joints: list[JointDescription],
root_link_name: str,
frames: list[LinkDescription] | None = None,
) -> tuple[
LinkDescription,
list[JointDescription],
list[LinkDescription],
list[LinkDescription],
list[JointDescription],
list[LinkDescription],
]:
"""
Low-level creator of kinematic graph components.
The children and the parent name of the given links are rewritten according
to the given joints, then the elements reachable from the root link are
separated from the unconnected ones.
Args:
links: A list of parsed link descriptions.
joints: A list of parsed joint descriptions.
root_link_name: The name of the root link used as root node of the graph.
frames: A list of parsed frame descriptions.
Returns:
A tuple containing the root node of the graph (defining the entire kinematic
tree by iterating on its child nodes), the list of joints representing the
actual graph edges, the list of frames rigidly attached to the graph nodes,
the list of unconnected links, the list of unconnected joints, and the list
of unconnected frames.
Raises:
ValueError: If the root link is not among the given links.
KeyError: If the parent or child link of a joint is not among the given links.
AssertionError: If the parent of a frame is empty, `None`, `__model__`,
or another frame.
"""
# Create a dictionary that maps the link name to the link, for easy retrieval.
# The links are switched to a mutable state without validation, since their
# connections are rewritten below.
links_dict: dict[str, LinkDescription] = {
l.name: l.mutable(validate=False) for l in links
}
# Create an empty list of frames if not provided.
frames = frames if frames is not None else []
# Create a dictionary that maps the frame name to the frame, for easy retrieval.
frames_dict = {frame.name: frame for frame in frames}
# Check that our parser correctly resolved the frame's parent to be a link.
for frame in frames:
assert frame.parent_name != "", frame
assert frame.parent_name is not None, frame
assert frame.parent_name != "__model__", frame
assert frame.parent_name not in frames_dict, frame
# ===========================================================
# Populate the kinematic graph with links, joints, and frames
# ===========================================================
# Check the existence of the root link.
if root_link_name not in links_dict:
raise ValueError(root_link_name)
# Reset the children of all the links, so that the connections are rebuilt
# from scratch using only the given joints.
for link in links_dict.values():
link.children = tuple()
# Couple links and joints creating the kinematic graph.
for joint in joints:
# Get the parent and child links of the joint.
parent_link = links_dict[joint.parent.name]
child_link = links_dict[joint.child.name]
assert child_link.name == joint.child.name
assert parent_link.name == joint.parent.name
# Assign link's parent.
child_link.parent_name = parent_link.name
# Assign link's children and make sure they are unique.
if child_link.name not in {l.name for l in parent_link.children}:
with parent_link.mutable_context(Mutability.MUTABLE_NO_VALIDATION):
parent_link.children = (*parent_link.children, child_link)
# Collect all the links of the kinematic graph.
# Only the links reachable from the root link are found by the search.
all_links_in_graph = list(
KinematicGraph.breadth_first_search(root=links_dict[root_link_name])
)
# Get the names of all links in the kinematic graph.
all_link_names_in_graph = [l.name for l in all_links_in_graph]
# Collect all the joints of the kinematic graph.
# A joint belongs to the graph only if both its links belong to it.
all_joints_in_graph = [
joint
for joint in joints
if joint.parent.name in all_link_names_in_graph
and joint.child.name in all_link_names_in_graph
]
# Get the names of all joints in the kinematic graph.
all_joint_names_in_graph = [j.name for j in all_joints_in_graph]
# Collect all the frames of the kinematic graph.
# Note: our parser ensures that the parent of a frame is not another frame.
all_frames_in_graph = [
frame for frame in frames if frame.parent_name in all_link_names_in_graph
]
# Get the names of all frames in the kinematic graph.
all_frames_names_in_graph = [f.name for f in all_frames_in_graph]
# ============================
# Collect unconnected elements
# ============================
# Collect all the joints that are not part of the kinematic graph.
removed_joints = [j for j in joints if j.name not in all_joint_names_in_graph]
for joint in removed_joints:
msg = "Joint '{}' is unconnected and it will be removed"
logging.debug(msg=msg.format(joint.name))
# Collect all the links that are not part of the kinematic graph.
unconnected_links = [l for l in links if l.name not in all_link_names_in_graph]
# Update the unconnected links by removing their children. The other properties
# are left untouched, it's caller responsibility to post-process them if needed.
for link in unconnected_links:
link.children = tuple()
msg = "Link '{}' won't be part of the kinematic graph because unconnected"
logging.debug(msg=msg.format(link.name))
# Collect all the frames that are not part of the kinematic graph.
unconnected_frames = [
f for f in frames if f.name not in all_frames_names_in_graph
]
for frame in unconnected_frames:
msg = "Frame '{}' won't be part of the kinematic graph because unconnected"
logging.debug(msg=msg.format(frame.name))
# NOTE(review): the two lists of joints are built from sets, therefore their
# order is presumably not the one of the input list -- confirm that no caller
# relies on it.
return (
links_dict[root_link_name].mutable(mutable=False),
list(set(joints) - set(removed_joints)),
all_frames_in_graph,
unconnected_links,
list(set(removed_joints)),
unconnected_frames,
)
def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph:
"""
Reduce the kinematic graph by removing unspecified joints.
When a joint is removed, the mass and inertia of its child link are lumped
with those of its parent link, obtaining a new link that combines the two.
The description of the removed joint specifies the default angle (usually 0)
that is considered when the joint is removed.
The removed links become frames of the reduced graph, and the pose of all
the frames is recomputed with respect to their parent link in the reduced graph.
Args:
considered_joints: A list of joint names to consider.
Returns:
The reduced kinematic graph. If no joint has to be removed, a deep copy
of the current graph is returned.
Raises:
ValueError: If a considered joint is not part of the graph and at least
one joint of the graph has to be removed.
"""
# The current object represents the complete kinematic graph
full_graph = self
# Get the names of the joints to remove
joint_names_to_remove = list(
set(full_graph.joint_names()) - set(considered_joints)
)
# Return early if there is no action to take
# NOTE(review): this return precedes the validation of the considered joints,
# therefore unknown joint names go unnoticed when nothing has to be removed.
if len(joint_names_to_remove) == 0:
logging.info("The kinematic graph doesn't need to be reduced")
return copy.deepcopy(self)
# Check if all considered joints are part of the full kinematic graph
if set(considered_joints) - {j.name for j in full_graph.joints}:
extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
# NOTE(review): the set is rendered inside an additional pair of literal braces.
msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
raise ValueError(msg)
# Extract data we need to modify from the full graph
# Deep copies are used so that the descriptions of the full graph are not altered.
links_dict = copy.deepcopy(full_graph.links_dict)
joints_dict = copy.deepcopy(full_graph.joints_dict)
# Create the object to compute forward kinematics.
fk = KinematicGraphTransforms(graph=full_graph)
# The following steps are implemented below in order to create the reduced graph:
#
# 1. Lump the mass of the removed links into their parent
# 2. Update the pose and parent link of joints having the removed link as parent
# 3. Create the reduced graph considering the removed links as frames
# 4. Resolve the pose of the frames wrt their reduced graph parent
#
# We name "removed link" the link to remove, and "lumped link" the new link that
# combines the removed link and its parent. The lumped link will share the frame
# of the removed link's parent and the inertial properties of the two links that
# have been combined.
# =======================================================
# 1. Lump the mass of the removed links into their parent
# =======================================================
# Get all the links to remove. They will be lumped with their parent.
links_to_remove = [
joint.child.name
for joint_name, joint in joints_dict.items()
if joint_name in joint_names_to_remove
]
# Lump the mass and the inertia traversing the tree from the leaf to the root,
# this way we propagate these properties back even in the case when also the
# parent link of a removed joint has to be lumped with its parent.
for link in reversed(full_graph):
if link.name not in links_to_remove:
continue
# Get the link to remove and its parent, i.e. the lumped link
link_to_remove = links_dict[link.name]
parent_of_link_to_remove = links_dict[link.parent_name]
msg = "Lumping chain: {}->({})->{}"
logging.debug(
msg.format(
link_to_remove.name,
self.joints_connection_dict[
parent_of_link_to_remove.name, link_to_remove.name
].name,
parent_of_link_to_remove.name,
)
)
# Lump the link
lumped_link = parent_of_link_to_remove.lump_with(
link=link_to_remove,
lumped_H_removed=fk.relative_transform(
relative_to=parent_of_link_to_remove.name, name=link_to_remove.name
),
)
# Pop the original two links from the dictionary...
_ = links_dict.pop(link_to_remove.name)
_ = links_dict.pop(parent_of_link_to_remove.name)
# ... and insert the lumped link (having the same name of the parent)
links_dict[lumped_link.name] = lumped_link
# Insert back in the dict an entry from the removed link name to the new
# lumped link. We need this info later, when we process the remaining joints.
links_dict[link_to_remove.name] = lumped_link
# As a consequence of the back-insertion, we need to adjust the resulting
# lumped link of links that have been removed previously.
# Note: in the dictionary, only items whose key is not matching value.name
# are links that have been removed.
for previously_removed_link_name in {
link_name
for link_name, link in links_dict.items()
if link_name != link.name and link.name == link_to_remove.name
}:
links_dict[previously_removed_link_name] = lumped_link
# ==============================================================================
# 2. Update the pose and parent link of joints having the removed link as parent
# ==============================================================================
# Find the joints having the removed links as parent
joints_with_removed_parent_link = [
joints_dict[joint_name]
for joint_name in considered_joints
if joints_dict[joint_name].parent.name in links_to_remove
]
# Update the pose of all joints having as parent link a removed link
for joint in joints_with_removed_parent_link:
# Update the pose. Note that after the lumping process, the dict entry
# links_dict[joint.parent.name] contains the final lumped link
with joint.mutable_context(mutability=Mutability.MUTABLE):
joint.pose = fk.relative_transform(
relative_to=links_dict[joint.parent.name].name, name=joint.name
)
with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
# Update the parent link
joint.parent = links_dict[joint.parent.name]
# ===================================================================
# 3. Create the reduced graph considering the removed links as frames
# ===================================================================
# Get all the original links from the full graph
full_graph_links_dict = copy.deepcopy(full_graph.links_dict)
# Get all the final links from the reduced graph
links_to_keep = [
l for link_name, l in links_dict.items() if link_name not in links_to_remove
]
# Override the entries of the full graph with those of the reduced graph.
# Those that are not overridden will become frames.
for link in links_to_keep:
full_graph_links_dict[link.name] = link
# Create the reduced graph data. We pass the full list of links so that those
# that are not part of the graph will be returned as frames.
# No frames are passed here, the frames of the current graph are added below.
(
reduced_root_node,
reduced_joints,
reduced_frames,
unconnected_links,
unconnected_joints,
unconnected_frames,
) = KinematicGraph._create_graph(
links=list(full_graph_links_dict.values()),
joints=[joints_dict[joint_name] for joint_name in considered_joints],
root_link_name=full_graph.root.name,
)
# The frames of the current graph must not clash with those returned above.
assert {f.name for f in self.frames}.isdisjoint(
{f.name for f in unconnected_frames + reduced_frames}
)
for link in unconnected_links:
logging.debug(msg=f"Link '{link.name}' is unconnected and became a frame")
# Create the reduced graph.
# Its frames are the original frames plus the links left out of the reduced graph.
reduced_graph = KinematicGraph(
root=reduced_root_node,
joints=reduced_joints,
frames=self.frames + unconnected_links + reduced_frames,
root_pose=full_graph.root_pose,
_joints_removed=(
self._joints_removed
+ unconnected_joints
+ [joints_dict[name] for name in joint_names_to_remove]
),
)
# ================================================================
# 4. Resolve the pose of the frames wrt their reduced graph parent
# ================================================================
# Build a new object to compute FK on the reduced graph.
fk_reduced = KinematicGraphTransforms(graph=reduced_graph)
# We need to adjust the pose of the frames since their parent link
# could have been removed by the reduction process.
for frame in reduced_graph.frames:
# Always find the real parent link of the frame
name_of_new_parent_link = fk_reduced.find_parent_link_of_frame(
name=frame.name
)
assert name_of_new_parent_link in reduced_graph, name_of_new_parent_link
# Notify the user if the parent link has changed.
if name_of_new_parent_link != frame.parent_name:
msg = "New parent of frame '{}' is '{}'"
logging.debug(msg=msg.format(frame.name, name_of_new_parent_link))
# Always recompute the pose of the frame, and set zero inertial params.
with frame.mutable_context(jaxsim.utils.Mutability.MUTABLE_NO_VALIDATION):
# Update kinematic parameters of the frame.
# Note that here we compute the transform using the FK object of the
# full model, so that we are sure that the kinematic is not altered.
frame.pose = fk.relative_transform(
relative_to=name_of_new_parent_link, name=frame.name
)
# Update the parent link such that the pose is expressed in its frame.
frame.parent_name = name_of_new_parent_link
# Update dynamic parameters of the frame.
frame.mass = 0.0
frame.inertia = np.zeros_like(frame.inertia)
# Return the reduced graph.
return reduced_graph
def link_names(self) -> list[str]:
    """
    List the names of the links of the kinematic graph (i.e. its nodes).

    Returns:
        The link names, in the order of the `links_dict` dictionary.
    """

    return [link_name for link_name in self.links_dict.keys()]
def joint_names(self) -> list[str]:
    """
    List the names of the joints of the kinematic graph (i.e. its edges).

    Returns:
        The joint names, in the order of the `joints_dict` dictionary.
    """

    return [joint_name for joint_name in self.joints_dict.keys()]
def frame_names(self) -> list[str]:
    """
    List the names of the frames of the kinematic graph.

    Returns:
        The frame names, in the order of the `frames_dict` dictionary.
    """

    return [frame_name for frame_name in self.frames_dict.keys()]
def print_tree(self) -> None:
    """
    Print the tree of links of the kinematic graph, starting from its root.

    Each node is labelled with its `name_and_index` attribute and its children
    are read from the `children` attribute. The `pptree` package is imported
    only when this method is called.
    """

    import pptree

    pptree.print_tree(
        self.root,
        childattr="children",
        nameattr="name_and_index",
        horizontal=True,
    )
@property
def joints_removed(self) -> list[JointDescription]:
    """
    Expose the private list of joints that are not part of the graph.

    Returns:
        The list stored in `_joints_removed` (the same object, not a copy).
    """

    removed_joints = self._joints_removed

    return removed_joints
@staticmethod
def breadth_first_search(
    root: LinkDescription,
    sort_children: Callable[[Any], Any] | None = lambda link: link.name,
) -> Iterable[LinkDescription]:
    """
    Perform a breadth-first search (BFS) traversal of the kinematic graph.

    Args:
        root: The root link for BFS.
        sort_children:
            The key function used to sort the children of each visited node
            (by default, children are sorted by name).

    Yields:
        The links in the kinematic graph in BFS order. A link reachable
        through multiple paths is yielded only the first time it is found.
    """

    # Imported locally so that this block does not depend on the module's
    # import section.
    from collections import deque

    # Initialize the queue with the root node.
    # A deque extracts the first element in constant time.
    queue = deque([root])

    # We assume that nodes have unique names and mark a link as visited using
    # its name. A set makes the membership test constant-time.
    visited = {root.name}

    yield root

    while queue:

        # Extract the first element of the queue.
        current_link = queue.popleft()

        # Note: sorting the links with their name so that the order of children
        # insertion does not matter when assigning the link index.
        for child in sorted(current_link.children, key=sort_children):

            if child.name in visited:
                continue

            visited.add(child.name)
            queue.append(child)

            yield child
# =================
# Sequence protocol
# =================
def __iter__(self) -> Iterator[LinkDescription]:
    # Yield the links in breadth-first order, starting from the root link.
    for link in KinematicGraph.breadth_first_search(root=self.root):
        yield link
def __reversed__(self) -> Iterable[LinkDescription]:
    # Collect all the links first, then yield them from the last to the first.
    links_in_order = list(iter(self))

    for link in links_in_order[::-1]:
        yield link
def __len__(self) -> int:
    # The number of links is computed by iterating on the whole graph.
    return sum(1 for _ in iter(self))
def __contains__(self, item: str | LinkDescription) -> bool:
    # Strings are matched against the names of the links.
    if isinstance(item, str):
        known_link_names = self.link_names()
        return item in known_link_names

    # Link descriptions are matched against the links of the graph.
    if isinstance(item, LinkDescription):
        links_of_graph = set(iter(self))
        return item in links_of_graph

    # Any other type is not supported.
    raise TypeError(type(item).__name__)
def __getitem__(self, key: int | str) -> LinkDescription:
    """
    Get a link of the graph either by name or by position.

    Args:
        key:
            The name of the link, or its position in the iteration order of
            the graph. Negative positions count from the end.

    Returns:
        The selected link.

    Raises:
        KeyError: If no link has the given name, or the position is out of range.
        TypeError: If the key is neither a string nor an integer.
    """

    if isinstance(key, str):

        if key not in self.link_names():
            raise KeyError(key)

        return self.links_dict[key]

    if isinstance(key, int):

        # Traverse the graph only once and reuse the list for the range check.
        links = list(iter(self))

        # Valid positions are those accepted by list indexing. Note that the
        # position equal to the number of links is already out of range.
        if not -len(links) <= key < len(links):
            raise KeyError(key)

        return links[key]

    raise TypeError(type(key).__name__)
def count(self, value: LinkDescription) -> int:
    """
    Count how many links of the kinematic graph match the given one.
    """

    links_of_graph = [link for link in iter(self)]

    return links_of_graph.count(value)
def index(
    self, value: LinkDescription, start: int = 0, stop: int | None = None
) -> int:
    """
    Find the position of a link in the iteration order of the kinematic graph.

    Args:
        value: The link to look for.
        start: The position from which the search starts.
        stop:
            The position at which the search stops (excluded). By default the
            search reaches the end of the graph, so that the last link can be
            found too.

    Returns:
        The position of the first matching link.

    Raises:
        ValueError: If the link is not found in the searched range.
    """

    links = list(iter(self))

    # list.index does not accept None as upper bound, hence the two calls.
    if stop is None:
        return links.index(value, start)

    return links.index(value, start, stop)
# ====================
# Other useful classes
# ====================
@dataclasses.dataclass(frozen=True)
class KinematicGraphTransforms:
"""
Class to compute forward kinematics on a kinematic graph.
Attributes:
graph: The kinematic graph on which to compute forward kinematics.
"""
graph: KinematicGraph
# Cache of the transforms already computed, indexed by element name.
# It is not a constructor argument and it is ignored by comparisons.
_transform_cache: dict[str, npt.NDArray] = dataclasses.field(
default_factory=dict, init=False, repr=False, compare=False
)
# Joint positions used to compute the transforms, indexed by joint name.
# It has no default here, and it is not a constructor argument either.
_initial_joint_positions: dict[str, float] = dataclasses.field(
init=False, repr=False, compare=False
)
def __post_init__(self) -> None:
    # Collect the initial position of each joint of the graph.
    initial_positions = {}

    for joint in self.graph.joints:
        initial_positions[joint.name] = joint.initial_position

    # The dataclass is frozen, therefore the attribute is set through the
    # parent class.
    super().__setattr__("_initial_joint_positions", initial_positions)
@property
def initial_joint_positions(self) -> npt.NDArray:
"""
Get the initial joint positions of the kinematic graph.
"""
return np.atleast_1d(
np.array(list(self._initial_joint_positions.values()))
).astype(float)
@initial_joint_positions.setter
def initial_joint_positions(
self,
positions: npt.NDArray | Sequence,
joint_names: Sequence[str] | None = None,
) -> None:
joint_names = (
joint_names
if joint_names is not None
else list(self._initial_joint_positions.keys())
)
s = np.atleast_1d(np.array(positions).squeeze())
if s.size != len(joint_names):
raise ValueError(s.size, len(joint_names))
for joint_name in joint_names:
if joint_name not in self._initial_joint_positions:
raise ValueError(joint_name)
# Clear transform cache.
self._transform_cache.clear()
# Update initial joint positions.
for joint_name, position in zip(joint_names, s, strict=True):
self._initial_joint_positions[joint_name] = position
def transform(self, name: str) -> npt.NDArray:
"""
Compute the SE(3) transform of elements belonging to the kinematic graph.
Args:
name: The name of a link, a joint, or a frame.
Returns:
The 4x4 transform matrix of the element w.r.t. the model frame.
"""
# If the transform was already computed, return it.
if name in self._transform_cache:
return self._transform_cache[name]
# If the name is a joint, compute M_H_J transform.
if name in self.graph.joint_names():
# Get the joint.
joint = self.graph.joints_dict[name]
assert joint.name == name
# Get the transform of the parent link.
M_H_L = self.transform(name=joint.parent.name)
# Rename the pose of the predecessor joint frame w.r.t. its parent link.
L_H_pre = joint.pose
# Compute the joint transform from the predecessor to the successor frame.
pre_H_J = self.pre_H_suc(
joint_type=joint.jtype,
joint_axis=joint.axis,
joint_position=self._initial_joint_positions[joint.name],
)
# Compute the M_H_J transform.
self._transform_cache[name] = M_H_L @ L_H_pre @ pre_H_J
return self._transform_cache[name]
# If the name is a link, compute M_H_L transform.
if name in self.graph.link_names():
# Get the link.
link = self.graph.links_dict[name]
# Handle the pose between the __model__ frame and the root link.
if link.name == self.graph.root.name:
M_H_B = link.pose
return M_H_B
# Get the joint between the link and its parent.
parent_joint = self.graph.joints_connection_dict[
link.parent_name, link.name
]
# Get the transform of the parent joint.
M_H_J = self.transform(name=parent_joint.name)
# Rename the pose of the link w.r.t. its parent joint.
J_H_L = link.pose
# Compute the M_H_L transform.
self._transform_cache[name] = M_H_J @ J_H_L
return self._transform_cache[name]
# It can only be a plain frame.
if name not in self.graph.frame_names():
raise ValueError(name)
# Get the frame.
frame = self.graph.frames_dict[name]
# Get the transform of the parent link.
M_H_L = self.transform(name=frame.parent_name)
# Rename the pose of the frame w.r.t. its parent link.
L_H_F = frame.pose
# Compute the M_H_F transform.
self._transform_cache[name] = M_H_L @ L_H_F
return self._transform_cache[name]
def relative_transform(self, relative_to: str, name: str) -> npt.NDArray:
"""
Compute the SE(3) relative transform of elements belonging to the kinematic graph.
Args:
relative_to: The name of the reference element.
name: The name of a link, a joint, or a frame.
Returns:
The 4x4 transform matrix of the element w.r.t. the desired frame.
"""
import jaxsim.math
M_H_target = self.transform(name=name)
M_H_R = self.transform(name=relative_to)
# Compute the relative transform R_H_target, where R is the reference frame,
# and i the frame of the desired link|joint|frame.
return np.array(jaxsim.math.Transform.inverse(M_H_R)) @ M_H_target
@staticmethod
def pre_H_suc(
joint_type: JointType,
joint_axis: npt.NDArray,
joint_position: float | None = None,
) -> npt.NDArray:
"""
Compute the SE(3) transform from the predecessor to the successor frame.
Args:
joint_type: The type of the joint.
joint_axis: The axis of the joint.
joint_position: The position of the joint.
Returns:
The 4x4 transform matrix from the predecessor to the successor frame.
"""
import jaxsim.math
return np.array(
jaxsim.math.supported_joint_motion(joint_type, joint_position, joint_axis)
)
def find_parent_link_of_frame(self, name: str) -> str:
"""
Find the parent link of a frame.
Args:
name: The name of the frame.
Returns:
The name of the parent link of the frame.
"""
try:
frame = self.graph.frames_dict[name]
except KeyError as e:
raise ValueError(f"Frame '{name}' not found in the kinematic graph") from e
if frame.parent_name in self.graph.links_dict:
return frame.parent_name
if frame.parent_name in self.graph.frames_dict:
return self.find_parent_link_of_frame(name=frame.parent_name)
msg = f"Failed to find parent element of frame '{name}' with name '{frame.parent_name}'"
raise RuntimeError(msg)
================================================
FILE: src/jaxsim/parsers/rod/__init__.py
================================================
from . import parser, utils
from .parser import build_model_description, extract_model_data
================================================
FILE: src/jaxsim/parsers/rod/meshes.py
================================================
import numpy as np
import trimesh
VALID_AXIS = {"x": 0, "y": 1, "z": 2}
def extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray:
    """
    Use the vertices of a mesh as its set of points.

    Args:
        mesh: The mesh from which to extract points.

    Returns:
        The `vertices` attribute of the mesh, returned as-is.
    """

    vertices = mesh.vertices
    return vertices
def extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n) -> np.ndarray:
    """
    Sample N random points from the surface of a mesh.

    The sampling is fully delegated to the `sample` method of the mesh.

    Args:
        mesh: The mesh from which to extract points.
        n: The number of points to extract.

    Returns:
        The extracted points (N x 3 array).
    """

    sampled_points = mesh.sample(n)
    return sampled_points
def extract_points_uniform_surface_sampling(
    mesh: trimesh.Trimesh, n: int
) -> np.ndarray:
    """
    Sample N approximately evenly spaced points from the surface of a mesh.

    Args:
        mesh: The mesh from which to extract points.
        n: The number of points to extract.

    Returns:
        The extracted points (N x 3 array).
    """

    # Only the first element of the sampling result (the points) is of interest.
    sampling_result = trimesh.sample.sample_surface_even(mesh=mesh, count=n)

    return sampling_result[0]
def extract_points_select_points_over_axis(
    mesh: trimesh.Trimesh, axis: str, direction: str, n: int
) -> np.ndarray:
    """
    Extract N points from a mesh along a specified axis. The points are selected based on their position along the axis.

    The vertices are ranked by their coordinate along `axis` and whole points
    (rows) are selected, so that each returned point is an actual vertex of the
    mesh. The mesh itself is not modified.

    Args:
        mesh: The mesh from which to extract points.
        axis: The axis along which to extract points. Valid values are "x", "y" and "z".
        direction: The direction along the axis from which to extract points. Valid values are "higher" and "lower".
        n: The number of points to extract. Non-positive values select no points.

    Returns:
        The extracted points (N x 3 array), sorted by increasing coordinate
        along the selected axis.

    Raises:
        KeyError: If either `axis` or `direction` is not a valid value.
    """

    if direction not in ("higher", "lower"):
        raise KeyError(direction)

    # View the vertices as a plain array: fancy indexing below creates a copy,
    # hence the vertices stored in the mesh are never altered.
    vertices = np.asarray(mesh.vertices)

    # Rank whole rows by the coordinate of the selected axis. Sorting the array
    # along axis 0 would instead sort each column independently, mixing the
    # coordinates of different points.
    order = np.argsort(vertices[:, VALID_AXIS[axis]], kind="stable")
    ranked = vertices[order]

    # Clamp the number of points to the valid range. This also prevents the
    # slice `[-0:]` from selecting all the points when n is zero.
    count = min(max(int(n), 0), ranked.shape[0])

    if direction == "higher":
        return ranked[ranked.shape[0] - count :]

    return ranked[:count]
def extract_points_aap(
    mesh: trimesh.Trimesh,
    axis: str,
    upper: float | None = None,
    lower: float | None = None,
) -> np.ndarray:
    """
    Extract the vertices of a mesh whose coordinate along an axis lies in a range.

    Both bounds are inclusive. A missing bound leaves the corresponding side
    of the range unbounded.

    Args:
        mesh: The mesh from which to extract points.
        axis: The axis along which to extract points.
        upper: The upper bound of the range.
        lower: The lower bound of the range.

    Returns:
        The extracted points (N x 3 array).

    Raises:
        AssertionError: If the lower bound is not smaller than the upper bound.
    """

    # Replace missing bounds with infinities.
    upper_bound = np.inf if upper is None else upper
    lower_bound = -np.inf if lower is None else lower

    assert lower_bound < upper_bound, "Invalid bounds for axis-aligned plane"

    # Coordinate of all the vertices along the selected axis.
    coordinate = mesh.vertices[:, VALID_AXIS[axis]]

    # Boolean mask of the vertices falling inside the closed range.
    inside_range = (coordinate >= lower_bound) & (coordinate <= upper_bound)

    return mesh.vertices[inside_range]
================================================
FILE: src/jaxsim/parsers/rod/parser.py
================================================
import dataclasses
import os
import pathlib
from typing import NamedTuple
import jax.numpy as jnp
import numpy as np
import rod
from jaxsim import logging
from jaxsim.math import Quaternion
from jaxsim.parsers import descriptions, kinematic_graph
from . import utils
class SDFData(NamedTuple):
    """
    Data extracted from an SDF resource useful to build a JaxSim model.
    """

    # Name of the parsed model.
    model_name: str

    # Whether the model is fixed-base (True) or floating-base (False).
    fixed_base: bool

    # Name of the link considered as base of the model.
    base_link_name: str

    # Descriptions of the parsed links.
    link_descriptions: list[descriptions.LinkDescription]

    # Descriptions of the parsed joints.
    joint_descriptions: list[descriptions.JointDescription]

    # Descriptions of the parsed frames, stored as link descriptions.
    frame_descriptions: list[descriptions.LinkDescription]

    # Collision shapes attached to the links of the model.
    collision_shapes: list[descriptions.CollisionShape]

    # The rod model from which the data was extracted, if available.
    sdf_model: rod.Model | None = None

    # Pose of the model. The default instance is created once, when the class
    # is defined, and it is shared by all the tuples that do not override it.
    model_pose: kinematic_graph.RootPose = kinematic_graph.RootPose()
def extract_model_data(
model_description: pathlib.Path | str | rod.Model | rod.Sdf,
model_name: str | None = None,
is_urdf: bool | None = None,
) -> SDFData:
"""
Extract data from an SDF/URDF resource useful to build a JaxSim model.
Args:
model_description:
A path to an SDF/URDF file, a string containing its content, or
a pre-parsed/pre-built rod model.
model_name: The name of the model to extract from the SDF resource.
is_urdf:
Whether to force parsing the resource as a URDF file. Automatically
detected if not provided.
Returns:
The extracted model data.
"""
match model_description:
case rod.Model():
sdf_model = model_description
case rod.Sdf() | str() | pathlib.Path():
sdf_element = (
model_description
if isinstance(model_description, rod.Sdf)
else rod.Sdf.load(sdf=model_description, is_urdf=is_urdf)
)
if not sdf_element.models():
raise RuntimeError("Failed to find any model in SDF resource")
# Assume the SDF resource has only one model, or the desired model name is given.
sdf_models = {m.name: m for m in sdf_element.models()}
sdf_model = (
sdf_element.models()[0]
if len(sdf_models) == 1
else sdf_models[model_name]
)
# Log model name.
logging.info(msg=f"Found model '{sdf_model.name}' in SDF resource")
# Jaxsim supports only models compatible with URDF, i.e. those having all links
# directly attached to their parent joint without additional roto-translations.
# Furthermore, the following switch also post-processes frames such that their
# pose is expressed wrt the parent link they are rigidly attached to.
sdf_model.switch_frame_convention(frame_convention=rod.FrameConvention.Urdf)
# Log type of base link.
logging.debug(
msg=f"Model '{sdf_model.name}' is {'fixed-base' if sdf_model.is_fixed_base() else 'floating-base'}"
)
# Log detected base link.
logging.debug(msg=f"Considering '{sdf_model.get_canonical_link()}' as base link")
# Pose of the model
if sdf_model.pose is None:
model_pose = kinematic_graph.RootPose()
else:
W_H_M = sdf_model.pose.transform()
model_pose = kinematic_graph.RootPose(
root_position=W_H_M[0:3, 3],
root_quaternion=Quaternion.from_dcm(dcm=W_H_M[0:3, 0:3]),
)
# ===========
# Parse links
# ===========
# Parse the links (unconnected).
links = [
descriptions.LinkDescription(
name=l.name,
mass=float(l.inertial.mass),
inertia=utils.from_sdf_inertial(inertial=l.inertial),
pose=l.pose.transform() if l.pose is not None else np.eye(4),
)
for l in sdf_model.links()
if l.inertial.mass > 0
]
# Create a dictionary to find easily links.
links_dict: dict[str, descriptions.LinkDescription] = {l.name: l for l in links}
# ============
# Parse frames
# ============
# Parse the frames (unconnected).
frames = [
descriptions.LinkDescription(
name=f.name,
mass=jnp.array(0.0, dtype=float),
inertia=jnp.zeros(shape=(3, 3)),
parent_name=f.attached_to,
pose=f.pose.transform() if f.pose is not None else jnp.eye(4),
)
for f in sdf_model.frames()
if f.attached_to in links_dict
]
# =========================
# Process fixed-base models
# =========================
# In this case, we need to get the pose of the joint that connects the base link
# to the world and combine their pose.
if sdf_model.is_fixed_base():
# Create a massless word link
world_link = descriptions.LinkDescription(
name="world", mass=0, inertia=np.zeros(shape=(6, 6))
)
# Gather joints connecting fixed-base models to the world.
# TODO: the pose of this joint could be expressed wrt any arbitrary frame,
# here we assume is expressed wrt the model. This also means that the
# default model pose matches the pose of the fake "world" link.
joints_with_world_parent = [
descriptions.JointDescription(
name=j.name,
parent=world_link,
child=links_dict[j.child],
jtype=utils.joint_to_joint_type(joint=j),
axis=(
np.array(j.axis.xyz.xyz)
if j.axis is not None
and j.axis.xyz is not None
and j.axis.xyz.xyz is not None
else None
),
pose=j.pose.transform() if j.pose is not None else np.eye(4),
)
for j in sdf_model.joints()
if j.type == "fixed"
and j.parent == "world"
and j.child in links_dict
and j.pose.relative_to in {"__model__", "world", None}
]
logging.debug(
f"Found joints connecting to world: {[j.name for j in joints_with_world_parent]}"
)
if len(joints_with_world_parent) != 1:
msg = "Found more/less than one joint connecting a fixed-base model to the world"
raise ValueError(msg + f": {[j.name for j in joints_with_world_parent]}")
base_link_name = joints_with_world_parent[0].child.name
msg = "Combining the pose of base link '{}' with the pose of joint '{}'"
logging.debug(msg.format(base_link_name, joints_with_world_parent[0].name))
# Combine the pose of the base link (child of the found fixed joint)
# with the pose of the fixed joint connecting with the world.
# Note: we assume it's a fixed joint and ignore any joint angle.
links_dict[base_link_name].mutable(validate=False).pose = (
joints_with_world_parent[0].pose @ links_dict[base_link_name].pose
)
# ============
# Parse joints
# ============
# Check that all joint poses are expressed w.r.t. their parent link.
for j in sdf_model.joints():
if j.pose is None:
continue
if j.parent == "world":
if j.pose.relative_to in {"__model__", "world", None}:
continue
raise ValueError("Pose of fixed joint connecting to 'world' link not valid")
if j.pose.relative_to != j.parent:
msg = "Pose of joint '{}' is not expressed wrt its parent link '{}'"
raise ValueError(msg.format(j.name, j.parent))
# Parse the joints.
joints = [
descriptions.JointDescription(
name=j.name,
parent=links_dict[j.parent],
child=links_dict[j.child],
jtype=utils.joint_to_joint_type(joint=j),
axis=(
np.array(j.axis.xyz.xyz, dtype=float)
if j.axis is not None
and j.axis.xyz is not None
and j.axis.xyz.xyz is not None
else None
),
pose=j.pose.transform() if j.pose is not None else np.eye(4),
initial_position=0.0,
position_limit=(
float(
j.axis.limit.lower
if j.axis is not None
and j.axis.limit is not None
and j.axis.limit.lower is not None
else jnp.finfo(float).min
),
float(
j.axis.limit.upper
if j.axis is not None
and j.axis.limit is not None
and j.axis.limit.upper is not None
else jnp.finfo(float).max
),
),
friction_static=float(
j.axis.dynamics.friction
if j.axis is not None
and j.axis.dynamics is not None
and j.axis.dynamics.friction is not None
else 0.0
),
friction_viscous=float(
j.axis.dynamics.damping
if j.axis is not None
and j.axis.dynamics is not None
and j.axis.dynamics.damping is not None
else 0.0
),
position_limit_damper=float(
j.axis.limit.dissipation
if j.axis is not None
and j.axis.limit is not None
and j.axis.limit.dissipation is not None
else os.environ.get("JAXSIM_JOINT_POSITION_LIMIT_DAMPER", 0.0)
),
position_limit_spring=float(
j.axis.limit.stiffness
if j.axis is not None
and j.axis.limit is not None
and j.axis.limit.stiffness is not None
else os.environ.get("JAXSIM_JOINT_POSITION_LIMIT_SPRING", 0.0)
),
)
for j in sdf_model.joints()
if j.type in {"revolute", "continuous", "prismatic", "fixed"}
and j.parent != "world"
and j.child in links_dict
]
# Create a dictionary to find the parent joint of the links.
joint_dict = {j.child.name: j.name for j in joints}
# Check that all the link poses are expressed wrt their parent joint.
for l in sdf_model.links():
if l.name not in links_dict:
continue
if l.pose is None:
continue
if l.name == sdf_model.get_canonical_link():
continue
if l.name not in joint_dict:
raise ValueError(f"Failed to find parent joint of link '{l.name}'")
if l.pose.relative_to != joint_dict[l.name]:
msg = "Pose of link '{}' is not expressed wrt its parent joint '{}'"
raise ValueError(msg.format(l.name, joint_dict[l.name]))
# ================
# Parse collisions
# ================
# Initialize the collision shapes
collisions: list[descriptions.CollisionShape] = []
# Parse the collisions
for link in sdf_model.links():
for collision in link.collisions():
if collision.geometry.box is not None:
box_collision = utils.create_box_collision(
collision=collision,
link_description=links_dict[link.name],
)
collisions.append(box_collision)
continue
if collision.geometry.sphere is not None:
sphere_collision = utils.create_sphere_collision(
collision=collision,
link_description=links_dict[link.name],
)
collisions.append(sphere_collision)
continue
if collision.geometry.mesh is not None:
if int(os.environ.get("JAXSIM_COLLISION_MESH_ENABLED", "0")):
logging.warning("Mesh collision support is still experimental.")
mesh_collision = utils.create_mesh_collision(
collision=collision,
link_description=links_dict[link.name],
method=utils.meshes.extract_points_vertices,
)
collisions.append(mesh_collision)
else:
logging.warning(
f"Skipping collision shape 'mesh' in link '{link.name}' because mesh collisions are disabled."
)
continue
# Check any remaining non-None geometry types.
for attr_name in collision.geometry.__dict__:
if getattr(collision.geometry, attr_name) is not None:
logging.warning(
f"Skipping collision shape '{attr_name}' in link '{link.name}' as not supported."
)
return SDFData(
model_name=sdf_model.name,
link_descriptions=links,
joint_descriptions=joints,
frame_descriptions=frames,
collision_shapes=collisions,
fixed_base=sdf_model.is_fixed_base(),
base_link_name=sdf_model.get_canonical_link(),
model_pose=model_pose,
sdf_model=sdf_model,
)
def build_model_description(
    model_description: pathlib.Path | str | rod.Model,
    is_urdf: bool | None = None,
) -> descriptions.ModelDescription:
    """
    Build a model description from an SDF/URDF resource.

    Args:
        model_description: A path to an SDF/URDF file, a string containing its content,
            or a pre-parsed/pre-built rod model.
        is_urdf: Whether to force parsing the resource as a URDF file. Automatically
            detected if not provided.

    Returns:
        The parsed model description.
    """

    # Parse data from the SDF assuming it contains a single model.
    data = extract_model_data(
        model_description=model_description, model_name=None, is_urdf=is_urdf
    )

    # Only the joints that are not fixed are considered by the model.
    non_fixed_joint_names = [
        joint.name
        for joint in data.joint_descriptions
        if joint.jtype is not descriptions.JointType.Fixed
    ]

    # Build the intermediate representation used for building a JaxSim model.
    # This process, beyond other operations, removes the fixed joints.
    # Note: if the model is fixed-base, the fixed joint between world and the first
    #       link is removed and the pose of the first link is updated.
    #
    # The whole process is:
    # URDF/SDF ⟶ rod.Model ⟶ ModelDescription ⟶ JaxSimModel.
    model = descriptions.ModelDescription.build_model_from(
        name=data.model_name,
        links=data.link_descriptions,
        joints=data.joint_descriptions,
        frames=data.frame_descriptions,
        collisions=data.collision_shapes,
        fixed_base=data.fixed_base,
        base_link_name=data.base_link_name,
        model_pose=data.model_pose,
        considered_joints=non_fixed_joint_names,
    )

    # Store the parsed SDF tree as extra info
    return dataclasses.replace(model, _extra_info={"sdf_model": data.sdf_model})
================================================
FILE: src/jaxsim/parsers/rod/utils.py
================================================
import os
import pathlib
from collections.abc import Callable
from typing import TypeVar
import numpy as np
import numpy.typing as npt
import rod
import trimesh
from rod.utils.resolve_uris import resolve_local_uri
import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.math import Adjoint, Inertia
from jaxsim.parsers import descriptions
from jaxsim.parsers.rod import meshes
MeshMappingMethod = TypeVar("MeshMappingMethod", bound=Callable[..., npt.NDArray])
def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
    """
    Build the 6D inertia matrix of a link from its SDF inertial element.

    Args:
        inertial: The SDF inertial element.

    Returns:
        The 6D inertia matrix of the link expressed in the link frame.
    """

    def or_zero(value):
        # Missing products of inertia are treated as zero.
        return value if value is not None else 0.0

    # Read the entries of the "inertia" element.
    inertia = inertial.inertia
    ixx, iyy, izz = inertia.ixx, inertia.iyy, inertia.izz
    ixy, ixz, iyz = or_zero(inertia.ixy), or_zero(inertia.ixz), or_zero(inertia.iyz)

    # Symmetric 3x3 inertia matrix expressed in the CoM.
    I_CoM = np.array(
        [
            [ixx, ixy, ixz],
            [ixy, iyy, iyz],
            [ixz, iyz, izz],
        ]
    )

    # 6x6 generalized inertia at the CoM, built from the "mass" element.
    M_CoM = Inertia.to_sixd(mass=inertial.mass, com=np.zeros(3), I=I_CoM)

    # Transform from the inertial frame (CoM) to the link frame.
    if inertial.pose is not None:
        L_H_CoM = inertial.pose.transform()
    else:
        L_H_CoM = np.eye(4)

    # The inverse transform is needed, as 6D adjoint.
    CoM_X_L = Adjoint.from_transform(transform=L_H_CoM, inverse=True)

    # Express the CoM inertia matrix in the link frame L.
    M_L = CoM_X_L.T @ M_CoM @ CoM_X_L

    return M_L.astype(dtype=float)
def joint_to_joint_type(joint: rod.Joint) -> int:
    """
    Extract the joint type from an SDF joint.

    Args:
        joint: The parsed SDF joint.

    Returns:
        The integer corresponding to the joint type.

    Raises:
        ValueError:
            If the joint is not fixed and its axis data is missing, or if the
            type of the joint is not supported.
    """

    joint_type = joint.type

    # Fixed joints do not need any axis.
    if joint_type == "fixed":
        return descriptions.JointType.Fixed

    axis = joint.axis

    # All the other joints need a valid axis. The axis element itself could be
    # missing, therefore it is checked before accessing its xyz data.
    if axis is None or axis.xyz is None or axis.xyz.xyz is None:
        raise ValueError("Failed to read axis xyz data")

    # Normalize the axis. It is only reported in the error raised below.
    axis_xyz = np.array(axis.xyz.xyz).astype(float)
    axis_xyz = axis_xyz / np.linalg.norm(axis_xyz)

    if joint_type in {"revolute", "continuous"}:
        return descriptions.JointType.Revolute

    if joint_type == "prismatic":
        return descriptions.JointType.Prismatic

    raise ValueError("Joint not supported", axis_xyz, joint_type)
def create_box_collision(
    collision: rod.Collision, link_description: descriptions.LinkDescription
) -> descriptions.BoxCollision:
    """
    Build the collision description of a box from an SDF collision element.

    The corners of the box are used as collidable points. Depending on the
    environment variable `JAXSIM_COLLISION_USE_BOTTOM_ONLY`, either all the
    eight corners or only the four bottom ones are considered.

    Args:
        collision: The SDF collision element.
        link_description: The link description.

    Returns:
        The box collision description.
    """

    size_x, size_y, size_z = collision.geometry.box.size

    # Half of the box size along each axis.
    center = np.array([size_x / 2, size_y / 2, size_z / 2])

    # Corners of the bottom face, with a vertex of the box in the origin.
    bottom_corners = np.array(
        [[0, 0, 0], [size_x, 0, 0], [size_x, size_y, 0], [0, size_y, 0]]
    )

    # The top corners are added only when the bottom-only mode is disabled.
    bottom_only_setting = os.environ.get("JAXSIM_COLLISION_USE_BOTTOM_ONLY", "0")

    if bottom_only_setting.lower() in {"false", "0"}:
        top_corners = np.array(
            [
                [0, 0, size_z],
                [size_x, 0, size_z],
                [size_x, size_y, size_z],
                [0, size_y, size_z],
            ]
        )
    else:
        top_corners = []

    # Stack the corners and shift them so that they surround the origin.
    box_corners = np.vstack([bottom_corners, *top_corners]) - center

    # Pose of the collision element (identity if missing).
    if collision.pose is not None:
        H = collision.pose.transform()
    else:
        H = np.eye(4)

    # NOTE(review): the corners are re-centered around the origin, whereas the
    # center below is the half-size vector transformed by H -- confirm that this
    # is the intended definition of the box center.
    center_wrt_link = (H @ np.hstack([center, 1.0]))[:-1]

    # Apply the pose to the corners using homogeneous coordinates.
    homogeneous_corners = np.hstack([box_corners, np.ones((box_corners.shape[0], 1))])
    box_corners_wrt_link = (H @ homogeneous_corners.T)[0:3, :]

    # Each transformed corner becomes an enabled collidable point of the link.
    collidable_points = []

    for corner in box_corners_wrt_link.T:
        collidable_points.append(
            descriptions.CollidablePoint(
                parent_link=link_description,
                position=np.array(corner),
                enabled=True,
            )
        )

    return descriptions.BoxCollision(
        collidable_points=collidable_points, center=center_wrt_link
    )
def create_sphere_collision(
    collision: rod.Collision, link_description: descriptions.LinkDescription
) -> descriptions.SphereCollision:
    """
    Build the collision description of a sphere from an SDF collision element.

    The surface of the sphere is discretized with a Fibonacci lattice. The number
    of points is read from the environment variable `JAXSIM_COLLISION_SPHERE_POINTS`
    (50 if not set), and `JAXSIM_COLLISION_USE_BOTTOM_ONLY` allows keeping only
    the points having a non-positive third coordinate.

    Args:
        collision: The SDF collision element.
        link_description: The link description.

    Returns:
        The sphere collision description.
    """

    # From https://stackoverflow.com/a/26127012
    def fibonacci_sphere(samples: int) -> npt.NDArray:

        # Angular increment between consecutive points, in radians.
        phi = np.pi * (3.0 - np.sqrt(5.0))

        # Generate the points on the unit sphere.
        points = []

        for i in range(samples):
            # The second coordinate goes linearly from 1 to -1.
            y = 1 - 2 * i / (samples - 1)

            # Radius of the circle at height y.
            ring_radius = np.sqrt(1 - y**2)

            points.append(
                np.array(
                    [
                        np.cos(phi * i) * ring_radius,
                        y,
                        np.sin(phi * i) * ring_radius,
                    ]
                )
            )

        # Filter to keep only the bottom half if required.
        bottom_only_setting = os.environ.get("JAXSIM_COLLISION_USE_BOTTOM_ONLY", "0")

        if bottom_only_setting.lower() in {"true", "1"}:
            # Keep only the points whose third coordinate is non-positive.
            points = [point for point in points if point[2] <= 0]

        return np.vstack(points)

    # Number of points used to discretize the sphere.
    number_of_points = int(os.environ.get("JAXSIM_COLLISION_SPHERE_POINTS", "50"))

    # Scale the unit sphere with the radius of the collision shape.
    r = collision.geometry.sphere.radius
    sphere_points = r * fibonacci_sphere(samples=number_of_points)

    # Pose of the collision element (identity if missing).
    if collision.pose is not None:
        H = collision.pose.transform()
    else:
        H = np.eye(4)

    # The center of the sphere is the origin of the collision frame.
    center_wrt_link = (H @ np.array([0.0, 0.0, 0.0, 1.0]))[:-1]

    # Apply the pose to the points using homogeneous coordinates.
    homogeneous_points = np.hstack(
        [sphere_points, np.vstack([1.0] * sphere_points.shape[0])]
    )
    sphere_points_wrt_link = (H @ homogeneous_points.T)[0:3, :]

    # Each transformed point becomes an enabled collidable point of the link.
    collidable_points = []

    for point in sphere_points_wrt_link.T:
        collidable_points.append(
            descriptions.CollidablePoint(
                parent_link=link_description,
                position=np.array(point),
                enabled=True,
            )
        )

    return descriptions.SphereCollision(
        collidable_points=collidable_points, center=center_wrt_link
    )
def create_mesh_collision(
    collision: rod.Collision,
    link_description: descriptions.LinkDescription,
    method: MeshMappingMethod = None,
) -> descriptions.MeshCollision:
    """
    Create a mesh collision from an SDF collision element.

    Args:
        collision: The SDF collision element.
        link_description: The link description.
        method:
            The method to use for mesh wrapping. It is called with the loaded mesh
            passed as `mesh` keyword argument, and it must return the extracted
            points. If not provided, the vertices of the mesh are used.

    Returns:
        The mesh collision description.

    Raises:
        RuntimeError: If the loaded mesh is empty.
    """

    # Resolve the URI of the mesh and load it, inferring the type from the suffix.
    file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri))
    file_type = file.suffix.replace(".", "")
    mesh = trimesh.load_mesh(file, file_type=file_type)

    if mesh.is_empty:
        raise RuntimeError(f"Failed to process '{file}' with trimesh")

    mesh.apply_scale(collision.geometry.mesh.scale)
    logging.info(
        msg=f"Loading mesh {collision.geometry.mesh.uri} with scale {collision.geometry.mesh.scale}, file type '{file_type}'"
    )

    if method is None:
        # Default to the vertex extraction function provided by the meshes module.
        method = meshes.extract_points_vertices
        logging.debug("Using default Vertex Extraction method for mesh wrapping")
    else:
        logging.debug(f"Using method {method} for mesh wrapping")

    points = method(mesh=mesh)
    logging.debug(f"Extracted {len(points)} points from mesh")

    # Pose of the collision element (identity if missing).
    W_H_L = collision.pose.transform() if collision.pose is not None else np.eye(4)

    # Extract translation from transformation matrix
    W_p_L = W_H_L[:3, 3]

    # Rotate and translate the points, stored as rows.
    mesh_points_wrt_link = points @ W_H_L[:3, :3].T + W_p_L

    # Each transformed point becomes an enabled collidable point of the link.
    collidable_points = [
        descriptions.CollidablePoint(
            parent_link=link_description,
            position=point,
            enabled=True,
        )
        for point in mesh_points_wrt_link
    ]

    return descriptions.MeshCollision(collidable_points=collidable_points, center=W_p_L)
def prepare_mesh_for_parametrization(
    mesh_uri: str, scale: tuple[float, float, float] = (1.0, 1.0, 1.0)
) -> dict:
    """
    Load and prepare a mesh for parametric scaling with exact inertia computation.

    The mesh is loaded and scaled. If it is not watertight, it is replaced by its
    convex hull. Finally, it is centered by subtracting its centroid from the
    vertices.

    Args:
        mesh_uri: URI/path to the mesh file.
        scale: Initial scale factors to apply (from SDF/URDF).

    Returns:
        A dictionary containing:
            - 'vertices': Centered mesh vertices as numpy array (Nx3)
            - 'faces': Triangle faces as numpy array (Mx3 integer indices)
            - 'offset': Original mesh centroid offset as numpy array (3,)
            - 'uri': The mesh URI for reference
            - 'is_watertight': Boolean indicating if the returned mesh is watertight
            - 'volume': The volume of the mesh (after scaling)

    Raises:
        RuntimeError: If the loaded mesh is empty.
    """

    # Resolve the URI and load the mesh, inferring the type from the suffix.
    mesh_path = pathlib.Path(resolve_local_uri(uri=mesh_uri))
    mesh_format = mesh_path.suffix.replace(".", "")
    mesh = trimesh.load_mesh(mesh_path, file_type=mesh_format)

    if mesh.is_empty:
        raise RuntimeError(f"Failed to process '{mesh_path}' with trimesh")

    # Apply initial scale from SDF/URDF
    mesh.apply_scale(scale)

    # Meshes that are not watertight are replaced by their convex hull.
    is_watertight = mesh.is_watertight

    if not is_watertight:
        logging.warning(
            f"Mesh {mesh_uri} is not watertight. Computing convex hull for valid inertia."
        )
        mesh = mesh.convex_hull
        is_watertight = True

    # Keep a copy of the centroid before centering the mesh around it.
    centroid = mesh.centroid.copy()
    mesh.vertices -= centroid

    prepared_mesh = {
        "vertices": np.array(mesh.vertices, dtype=np.float64),
        "faces": np.array(mesh.faces, dtype=np.int32),
        "offset": np.array(centroid, dtype=np.float64),
        "uri": mesh_uri,
        "is_watertight": is_watertight,
        "volume": mesh.volume,
    }

    return prepared_mesh
================================================
FILE: src/jaxsim/rbda/__init__.py
================================================
from . import actuation, contacts
from .aba import aba
from .aba_parallel import aba_parallel
from .collidable_points import collidable_points_pos_vel
from .crba import crba
from .forward_kinematics import forward_kinematics_model
from .forward_kinematics_parallel import forward_kinematics_model_parallel
from .jacobian import (
jacobian,
jacobian_derivative_full_doubly_left,
jacobian_full_doubly_left,
)
from .kinematic_constraints import compute_constraint_wrenches
from .mass_inverse import mass_inverse
from .rnea import rnea
================================================
FILE: src/jaxsim/rbda/aba.py
================================================
import jax
import jax.numpy as jnp
import jaxlie
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import STANDARD_GRAVITY, Adjoint, Cross
from . import utils
def aba(
model: js.model.JaxSimModel,
*,
base_position: jtp.VectorLike,
base_quaternion: jtp.VectorLike,
joint_positions: jtp.VectorLike,
base_linear_velocity: jtp.VectorLike,
base_angular_velocity: jtp.VectorLike,
joint_velocities: jtp.VectorLike,
joint_transforms: jtp.MatrixLike,
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,
) -> tuple[jtp.Vector, jtp.Vector]:
"""
Compute forward dynamics using the Articulated Body Algorithm (ABA).
Args:
model: The model to consider.
base_position: The position of the base link.
base_quaternion: The quaternion of the base link.
joint_positions: The positions of the joints.
base_linear_velocity:
The linear velocity of the base link in inertial-fixed representation.
base_angular_velocity:
The angular velocity of the base link in inertial-fixed representation.
joint_velocities: The velocities of the joints.
joint_transforms: The parent-to-child transforms of the joints.
joint_forces: The forces applied to the joints.
link_forces:
The forces applied to the links expressed in the world frame.
standard_gravity: The standard gravity constant.
Returns:
A tuple containing the base acceleration in inertial-fixed representation
and the joint accelerations that result from the applications of the given
joint and link forces.
Note:
The algorithm expects a quaternion with unit norm.
"""
W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, τ, W_f, W_g = utils.process_inputs(
model=model,
base_position=base_position,
base_quaternion=base_quaternion,
joint_positions=joint_positions,
base_linear_velocity=base_linear_velocity,
base_angular_velocity=base_angular_velocity,
joint_velocities=joint_velocities,
base_linear_acceleration=None,
base_angular_acceleration=None,
joint_accelerations=None,
joint_forces=joint_forces,
link_forces=link_forces,
standard_gravity=standard_gravity,
)
W_g = jnp.atleast_2d(W_g).T
W_v_WB = jnp.atleast_2d(W_v_WB).T
# Get the 6D spatial inertia matrices of all links.
M = js.model.link_spatial_inertia_matrices(model=model)
# Get the parent array λ(i).
# Note: λ(0) must not be used, it's initialized to -1.
λ = model.kin_dyn_parameters.parent_array
# Compute the base transform.
W_H_B = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3(wxyz=W_Q_B),
translation=W_p_B,
)
# Compute 6D transforms of the base velocity.
W_X_B = W_H_B.adjoint()
B_X_W = W_H_B.inverse().adjoint()
# Extract the parent-to-child adjoints of the joints.
i_X_λi = jnp.asarray(joint_transforms)
# Extract the joint motion subspaces.
S = model.kin_dyn_parameters.motion_subspaces
# Allocate buffers.
v = jnp.zeros(shape=(model.number_of_links(), 6, 1))
c = jnp.zeros(shape=(model.number_of_links(), 6, 1))
pA = jnp.zeros(shape=(model.number_of_links(), 6, 1))
MA = jnp.zeros(shape=(model.number_of_links(), 6, 6))
# Allocate the buffer of transforms link -> base.
i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
i_X_0 = i_X_0.at[0].set(jnp.eye(6))
# Initialize base quantities.
if model.floating_base():
# Base velocity v₀ in body-fixed representation.
v_0 = B_X_W @ W_v_WB
v = v.at[0].set(v_0)
# Initialize the articulated-body inertia (Mᴬ) of base link.
MA_0 = M[0]
MA = MA.at[0].set(MA_0)
# Initialize the articulated-body bias force (pᴬ) of the base link.
pA_0 = Cross.vx_star(v[0]) @ MA[0] @ v[0] - W_X_B.T @ jnp.vstack(W_f[0])
pA = pA.at[0].set(pA_0)
# ======
# Pass 1
# ======
Pass1Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
pass_1_carry: Pass1Carry = (v, c, MA, pA, i_X_0)
# Propagate kinematics and initialize AB inertia and AB bias forces.
def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]:
ii = i - 1
v, c, MA, pA, i_X_0 = carry
# Project the joint velocity into its motion subspace.
vJ = S[i] * ṡ[ii]
# Propagate the link velocity.
v_i = i_X_λi[i] @ v[λ[i]] + vJ
v = v.at[i].set(v_i)
c_i = Cross.vx(v[i]) @ vJ
c = c.at[i].set(c_i)
# Initialize the articulated-body inertia.
MA_i = jnp.array(M[i])
MA = MA.at[i].set(MA_i)
# Compute the link-to-base transform.
i_Xi_0 = i_X_λi[i] @ i_X_0[λ[i]]
i_X_0 = i_X_0.at[i].set(i_Xi_0)
# Compute link-to-world transform for the 6D force.
i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T
# Initialize articulated-body bias force.
pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(W_f[i])
pA = pA.at[i].set(pA_i)
return (v, c, MA, pA, i_X_0), None
(v, c, MA, pA, i_X_0), _ = (
jax.lax.scan(
f=loop_body_pass1,
init=pass_1_carry,
xs=jnp.arange(start=1, stop=model.number_of_links()),
)
if model.number_of_links() > 1
else [(v, c, MA, pA, i_X_0), None]
)
# ======
# Pass 2
# ======
U = jnp.zeros_like(S)
d = jnp.zeros(shape=(model.number_of_links(), 1))
u = jnp.zeros(shape=(model.number_of_links(), 1))
Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
pass_2_carry: Pass2Carry = (U, d, u, MA, pA)
def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:
ii = i - 1
U, d, u, MA, pA = carry
U_i = MA[i] @ S[i]
U = U.at[i].set(U_i)
d_i = S[i].T @ U[i]
d = d.at[i].set(d_i.squeeze())
u_i = τ[ii] - S[i].T @ pA[i]
u = u.at[i].set(u_i.squeeze())
# Compute the articulated-body inertia and bias force of this link.
Ma = MA[i] - U[i] / d[i] @ U[i].T
pa = pA[i] + Ma @ c[i] + U[i] * (u[i] / d[i])
# Propagate them to the parent, handling the base link.
def propagate(
MA_pA: tuple[jtp.Matrix, jtp.Matrix],
) -> tuple[jtp.Matrix, jtp.Matrix]:
MA, pA = MA_pA
MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
MA = MA.at[λ[i]].set(MA_λi)
pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa
pA = pA.at[λ[i]].set(pA_λi)
return MA, pA
MA, pA = jax.lax.cond(
pred=jnp.logical_or(λ[i] != 0, model.floating_base()),
true_fun=propagate,
false_fun=lambda MA_pA: MA_pA,
operand=(MA, pA),
)
return (U, d, u, MA, pA), None
(U, d, u, MA, pA), _ = (
jax.lax.scan(
f=loop_body_pass2,
init=pass_2_carry,
xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
)
if model.number_of_links() > 1
else [(U, d, u, MA, pA), None]
)
# ======
# Pass 3
# ======
if model.floating_base():
a0 = jnp.linalg.solve(-MA[0], pA[0])
else:
a0 = -B_X_W @ W_g
s̈ = jnp.zeros_like(s)
a = jnp.zeros_like(v).at[0].set(a0)
Pass3Carry = tuple[jtp.Matrix, jtp.Vector]
pass_3_carry = (a, s̈)
def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]:
ii = i - 1
a, s̈ = carry
# Propagate the link acceleration.
a_i = i_X_λi[i] @ a[λ[i]] + c[i]
# Compute the joint acceleration.
s̈_ii = (u[i] - U[i].T @ a_i) / d[i]
s̈ = s̈.at[ii].set(s̈_ii.squeeze())
# Sum the joint acceleration to the parent link acceleration.
a_i = a_i + S[i] * s̈[ii]
a = a.at[i].set(a_i)
return (a, s̈), None
(a, s̈), _ = (
jax.lax.scan(
f=loop_body_pass3,
init=pass_3_carry,
xs=jnp.arange(1, model.number_of_links()),
)
if model.number_of_links() > 1
else [(a, s̈), None]
)
# ==============
# Adjust outputs
# ==============
# TODO: remove vstack and shape=(6, 1)?
if model.floating_base():
# Convert the base acceleration to inertial-fixed representation,
# and add gravity.
B_a_WB = a[0]
W_a_WB = W_X_B @ B_a_WB + W_g
else:
W_a_WB = jnp.zeros(6)
return W_a_WB.squeeze(), jnp.atleast_1d(s̈.squeeze())
================================================
FILE: src/jaxsim/rbda/aba_parallel.py
================================================
import math
import jax
import jax.numpy as jnp
import jaxlie
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import STANDARD_GRAVITY, Adjoint, Cross
from . import utils
def aba_parallel(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike,
    base_quaternion: jtp.VectorLike,
    joint_positions: jtp.VectorLike,
    base_linear_velocity: jtp.VectorLike,
    base_angular_velocity: jtp.VectorLike,
    joint_velocities: jtp.VectorLike,
    joint_transforms: jtp.MatrixLike,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute forward dynamics using a hybrid parallel ABA.

    Passes 1 and 3 use pointer jumping in O(log D) parallel steps.
    Pass 2 uses level-parallel processing in O(D) steps because the backward
    inertia accumulation is not associative.

    The interface and semantics are identical to :func:`aba`,
    but passes 1 and 3 are parallelized via pointer jumping.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link (consumed as `wxyz`).
        joint_positions: The positions of the joints.
        base_linear_velocity: The linear velocity of the base link.
        base_angular_velocity: The angular velocity of the base link.
        joint_velocities: The velocities of the joints.
        joint_transforms:
            The parent-to-child adjoints of the joints. In this function the
            array is indexed with link indices (entry 0 is never used as a
            transform since the root coefficients are overwritten).
        joint_forces: The optional forces applied to the joints.
        link_forces: The optional 6D forces applied to the links.
        standard_gravity: The standard gravity constant.

    Returns:
        A tuple containing the 6D acceleration of the base link and the
        accelerations of the joints. For fixed-base models the base
        acceleration is a vector of six zeros. The joint accelerations are
        always returned as an array with at least one dimension.

    Note:
        All the inputs are pre-processed by `utils.process_inputs`, refer to
        that function for the accepted shapes and the applied conversions.
    """

    W_p_B, W_Q_B, _, W_v_WB, ṡ, _, _, τ, W_f, W_g = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
        base_linear_acceleration=None,
        base_angular_acceleration=None,
        joint_accelerations=None,
        joint_forces=joint_forces,
        link_forces=link_forces,
        standard_gravity=standard_gravity,
    )

    # Store the 6D gravity and base velocity as column vectors.
    W_g = jnp.atleast_2d(W_g).T
    W_v_WB = jnp.atleast_2d(W_v_WB).T

    # Get the 6D spatial inertia matrices of all links.
    M = js.model.link_spatial_inertia_matrices(model=model)

    # Get the parent array λ(i).
    λ = model.kin_dyn_parameters.parent_array

    # Get the tree level structure for level-parallel processing.
    # Each row of `level_nodes` lists the nodes of one level, and the matching
    # row of `level_mask` flags which entries are real (non-padded) nodes.
    level_nodes = jnp.asarray(model.kin_dyn_parameters.level_nodes)
    level_mask = jnp.asarray(model.kin_dyn_parameters.level_mask)
    n_levels = level_nodes.shape[0]

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3(wxyz=W_Q_B),
        translation=W_p_B,
    )

    # Compute 6D transforms of the base velocity.
    W_X_B = W_H_B.adjoint()
    B_X_W = W_H_B.inverse().adjoint()

    # Extract the parent-to-child adjoints of the joints.
    i_X_λi = jnp.asarray(joint_transforms)

    # Extract the joint motion subspaces.
    S = model.kin_dyn_parameters.motion_subspaces

    n = model.number_of_links()

    # Parent array with root self-loop.
    # Note: λ(0) is set to 0 to enable root self-referencing.
    ptr0 = jnp.asarray(λ).at[0].set(0)

    # Number of pointer-jumping rounds.
    # This is a Python integer computed from the static number of levels,
    # and it is always at least 1.
    n_rounds = max(1, math.ceil(math.log2(max(n_levels, 2))))

    # ======
    # Pass 1
    # ======

    # Two coupled affine recurrences propagated via pointer jumping:
    #   v_i = i_X_λi[i] @ v_parent + vJ_i
    #   T_i = i_X_λi[i] @ T_parent
    #
    # Associative operator on (A, b, T):
    #   compose(parent, child) = (A @ A_p, A @ b_p + b, A @ T_p)

    # Local transforms and joint velocities.
    # A zero row is prepended so that the joint velocities can be indexed
    # with link indices (the base link has no joint).
    ṡ_col = jnp.atleast_1d(ṡ).reshape(-1, 1)  # (n_joints, 1)
    ṡ_padded = jnp.concatenate([jnp.zeros((1, 1)), ṡ_col])  # (n, 1)
    vJ = S * ṡ_padded[:, :, None]  # (n, 6, 1)

    # Initialize pointer-jumping state for each node.
    A = i_X_λi.copy()  # (n, 6, 6)
    b = vJ.copy()  # (n, 6, 1)
    T = i_X_λi.copy()  # (n, 6, 6)

    # Root initial values.
    # The root carries the identity transform and either the body-fixed base
    # velocity (floating base) or a zero velocity (fixed base).
    if model.floating_base():
        v_0 = B_X_W @ W_v_WB
        A = A.at[0].set(jnp.eye(6))
        b = b.at[0].set(v_0)
        T = T.at[0].set(jnp.eye(6))
    else:
        A = A.at[0].set(jnp.eye(6))
        b = b.at[0].set(jnp.zeros((6, 1)))
        T = T.at[0].set(jnp.eye(6))

    # Every node initially points to its parent, and only the root is done.
    ptr = ptr0.copy()
    done = jnp.arange(n) == 0

    def _pass1_jump(carry, _):
        # One pointer-jumping round: every node that is not done composes its
        # state with the state of the node it points to, then jumps its
        # pointer to that node's own pointer.
        A, b, T, ptr, done = carry
        need = ~done

        A_par = A[ptr]
        b_par = b[ptr]
        T_par = T[ptr]

        # Associative compose.
        A_new = jnp.where(need[:, None, None], A @ A_par, A)
        b_new = jnp.where(need[:, None, None], A @ b_par + b, b)
        T_new = jnp.where(need[:, None, None], A @ T_par, T)

        ptr_new = jnp.where(need, ptr[ptr], ptr)

        # A node becomes done as soon as the node it pointed to was done.
        done_new = done | done[ptr]

        return (A_new, b_new, T_new, ptr_new, done_new), None

    # The scan is skipped for single-link models.
    (_, v, i_X_0, _, _), _ = (
        jax.lax.scan(
            f=_pass1_jump,
            init=(A, b, T, ptr, done),
            xs=jnp.arange(n_rounds),
        )
        if n > 1
        else ((A, b, T, ptr, done), None)
    )

    # v now contains the 6D body velocity of every link.
    # i_X_0 contains the body-to-base transform for every link.

    # Compute c, MA, pA for all nodes in parallel.
    def _init_node(node_i):
        # Velocity-product term, initial articulated-body inertia, and initial
        # articulated-body bias force of a single link.
        vJ_i = S[node_i] * ṡ_padded[node_i]
        c_i = Cross.vx(v[node_i]) @ vJ_i
        MA_i = M[node_i]

        # Link-to-world transform for the 6D force.
        i_Xf_W = Adjoint.inverse(i_X_0[node_i] @ B_X_W).T

        pA_i = Cross.vx_star(v[node_i]) @ M[node_i] @ v[node_i] - i_Xf_W @ jnp.vstack(
            W_f[node_i]
        )
        return c_i, MA_i, pA_i

    c, MA, pA = jax.vmap(_init_node)(jnp.arange(n))

    # Override base MA and pA if floating base.
    if model.floating_base():
        MA = MA.at[0].set(M[0])
        pA_0 = Cross.vx_star(v[0]) @ M[0] @ v[0] - W_X_B.T @ jnp.vstack(W_f[0])
        pA = pA.at[0].set(pA_0)

    # ======
    # Pass 2
    # ======

    # The Schur complement and multi-child scatter-add make this pass
    # non-associative, so it remains level-parallel.
    U = jnp.zeros_like(S)
    d = jnp.ones(shape=(n, 1))  # Ones to avoid NaN for the base node.
    u = jnp.zeros(shape=(n, 1))

    def _masked_scatter_add(arr, indices, values, m):
        """Add values[j] to arr[indices[j]] only where m[j] is True."""
        # Reshape the mask so that it broadcasts over the trailing value axes.
        mask = jnp.reshape(m, m.shape + (1,) * (values.ndim - 1))
        masked_values = jnp.where(mask, values, jnp.zeros_like(values))
        return arr.at[indices].add(masked_values)

    def _pass2_level(carry, level_idx):
        # Process one tree level, walking from the deepest level towards the
        # root as `level_idx` increases.
        U, d, u, MA, pA = carry
        actual_level = n_levels - 1 - level_idx
        nodes = level_nodes[actual_level]
        mask = level_mask[actual_level]

        def _process_node_pass2(node_i):
            # Clamp index to avoid out-of-bounds for padded entries.
            ii = jnp.maximum(node_i, 1) - 1
            parent = λ[node_i]

            U_i = MA[node_i] @ S[node_i]
            d_i = (S[node_i].T @ U_i).squeeze()
            u_i = (τ[ii] - S[node_i].T @ pA[node_i]).squeeze()

            # Articulated-body inertia and bias force seen through the joint.
            Ma_i = MA[node_i] - U_i / d_i @ U_i.T
            pa_i = pA[node_i] + Ma_i @ c[node_i] + U_i * (u_i / d_i)

            # Contributions expressed in the frame of the parent link.
            Ma_parent = i_X_λi[node_i].T @ Ma_i @ i_X_λi[node_i]
            pa_parent = i_X_λi[node_i].T @ pa_i

            return U_i, d_i, u_i, Ma_parent, pa_parent, parent

        U_lev, d_lev, u_lev, Ma_par, pa_par, parents = jax.vmap(_process_node_pass2)(
            nodes
        )

        mask_6x1 = mask[:, None, None]
        mask_1 = mask[:, None]

        # Store the per-node results, keeping the previous values for the
        # masked-out (padded) entries.
        U = carry[0].at[nodes].set(jnp.where(mask_6x1, U_lev, carry[0][nodes]))
        d = carry[1].at[nodes].set(jnp.where(mask_1, d_lev[:, None], carry[1][nodes]))
        u = carry[2].at[nodes].set(jnp.where(mask_1, u_lev[:, None], carry[2][nodes]))

        # For fixed-base models nothing is accumulated into the base link.
        should_propagate = jnp.where(
            model.floating_base(),
            mask,
            jnp.logical_and(mask, parents != 0),
        )

        MA = _masked_scatter_add(carry[3], parents, Ma_par, should_propagate)
        pA = _masked_scatter_add(carry[4], parents, pa_par, should_propagate)

        return (U, d, u, MA, pA), None

    # The root level is excluded from the backward sweep.
    n_backward_levels = n_levels - 1

    (U, d, u, MA, pA), _ = (
        jax.lax.scan(
            f=_pass2_level,
            init=(U, d, u, MA, pA),
            xs=jnp.arange(n_backward_levels),
        )
        if n_backward_levels > 0
        else ((U, d, u, MA, pA), None)
    )

    # ======
    # Pass 3
    # ======

    # The acceleration recurrence is an affine recurrence:
    #   a_i = P_i @ i_X_λi[i] @ a_parent + P_i @ c_i + S_i * u_i / d_i
    # where P_i = I - S_i @ U_i^T / d_i is the 6x6 projection matrix.
    if model.floating_base():
        a0 = jnp.linalg.solve(-MA[0], pA[0])
    else:
        a0 = -B_X_W @ W_g

    # Pre-compute the affine recurrence coefficients for all nodes.
    def _init_pass3(node_i):
        P_i = jnp.eye(6) - S[node_i] @ U[node_i].T / d[node_i]
        A_i = P_i @ i_X_λi[node_i]
        b_i = P_i @ c[node_i] + S[node_i] * (u[node_i] / d[node_i])
        return A_i, b_i

    A, b = jax.vmap(_init_pass3)(jnp.arange(n))

    # Root acceleration is known.
    A = A.at[0].set(jnp.eye(6))
    b = b.at[0].set(a0)

    # Pointer jumping for the affine recurrence.
    ptr = ptr0.copy()
    done = jnp.arange(n) == 0

    def _pass3_jump(carry, _):
        # Same pointer-jumping round used in pass 1, without the transform T.
        A, b, ptr, done = carry
        need = ~done

        A_par = A[ptr]
        b_par = b[ptr]

        # Associative compose.
        A_new = jnp.where(need[:, None, None], A @ A_par, A)
        b_new = jnp.where(need[:, None, None], A @ b_par + b, b)

        ptr_new = jnp.where(need, ptr[ptr], ptr)
        done_new = done | done[ptr]

        return (A_new, b_new, ptr_new, done_new), None

    # The scan is skipped for single-link models.
    (_, a, _, _), _ = (
        jax.lax.scan(
            f=_pass3_jump,
            init=(A, b, ptr, done),
            xs=jnp.arange(n_rounds),
        )
        if n > 1
        else ((A, b, ptr, done), None)
    )

    # Recover joint accelerations: s̈_i = (u_i - U_i^T @ a_before_i) / d_i
    # where a_before_i = i_X_λi[i] @ a_parent + c_i.
    a_λi = a[ptr0]
    a_before = i_X_λi @ a_λi + c
    Ut_a = (U.transpose(0, 2, 1) @ a_before).squeeze(-1)  # (n, 1)
    s̈ = (u - Ut_a) / d  # (n, 1)

    # ==============
    # Adjust outputs
    # ==============

    if model.floating_base():
        # Convert the base acceleration to inertial-fixed representation,
        # and add gravity.
        B_a_WB = a[0]
        W_a_WB = W_X_B @ B_a_WB + W_g
    else:
        W_a_WB = jnp.zeros(6)

    # Joint accelerations: skip base index, take indices 1..n-1.
    s̈_out = s̈[1:]

    return W_a_WB.squeeze(), jnp.atleast_1d(s̈_out.squeeze())
================================================
FILE: src/jaxsim/rbda/actuation/__init__.py
================================================
from .common import ActuationParams
================================================
FILE: src/jaxsim/rbda/actuation/common.py
================================================
import dataclasses
import jax_dataclasses
from jax_dataclasses import Static
import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass
@jax_dataclasses.pytree_dataclass
class ActuationParams(JaxsimDataclass):
    """
    Container of the parameters used by the actuation model.

    All the attributes have a default value, therefore the class can be
    instantiated without arguments.
    """

    # Defaults to 3000.0 (Nm).
    # NOTE(review): presumably the maximum torque of the actuator -- confirm.
    torque_max: jtp.Float = 3000.0

    # Defaults to 30.0 (rad/s).
    # NOTE(review): presumably a velocity threshold -- confirm against the
    # actuation model that consumes it.
    omega_th: jtp.Float = 30.0

    # Defaults to 100.0 (rad/s).
    # NOTE(review): presumably the maximum velocity -- confirm.
    omega_max: jtp.Float = 100.0

    # Declared as `Static`; enabled by default.
    enable_friction: Static[bool] = True
================================================
FILE: src/jaxsim/rbda/collidable_points.py
================================================
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Skew
def collidable_points_pos_vel(
    model: js.model.JaxSimModel,
    *,
    link_transforms: jtp.Matrix,
    link_velocities: jtp.Matrix,
) -> tuple[jtp.Matrix, jtp.Matrix]:
    """
    Compute position and linear velocity of the enabled collidable points.

    Both quantities are expressed in the world frame.

    Args:
        model: The model to consider.
        link_transforms: The transforms from the world frame to each link.
        link_velocities: The linear and angular velocities of each link.

    Returns:
        A tuple with the positions and the linear velocities of the enabled
        collidable points.

    Note:
        When no collidable point is enabled, the function returns a 0-D zero
        array and an empty 1-D array.
    """

    contact_parameters = model.kin_dyn_parameters.contact_parameters

    # Indices selecting the enabled collidable points.
    enabled = contact_parameters.indices_of_enabled_collidable_points

    # Parent link index and link-frame position of each enabled point.
    parent_links = jnp.array(contact_parameters.body, dtype=int)[enabled]
    local_points = contact_parameters.point[enabled]

    # NOTE(review): the two placeholders below have different shapes, `()` and
    # `(0,)`. Kept as they are since callers might rely on them -- confirm.
    if len(enabled) == 0:
        return jnp.array(0).astype(float), jnp.empty(0).astype(float)

    def point_kinematics(
        local_point: jtp.Vector, link_index: jtp.Int
    ) -> tuple[jtp.Vector, jtp.Vector]:
        # Position of the point: apply the homogeneous transform of the parent
        # link and keep the first three components.
        homogeneous_point = jnp.hstack([local_point, 1])
        world_position = (link_transforms[link_index] @ homogeneous_point)[0:3]

        # Linear part of the mixed velocity of the point: [I, -S(p)] maps the
        # 6D link velocity to the linear velocity of the point.
        velocity_map = jnp.block(
            [jnp.eye(3), -Skew.wedge(vector=world_position).squeeze()]
        )
        linear_velocity = velocity_map @ link_velocities[link_index].squeeze()

        return world_position, linear_velocity

    # Vectorize the computation over all the enabled points.
    positions, linear_velocities = jax.vmap(point_kinematics)(
        local_points, parent_links
    )

    return positions, linear_velocities
================================================
FILE: src/jaxsim/rbda/contacts/__init__.py
================================================
from . import relaxed_rigid, rigid, soft
from .common import ContactModel, ContactsParams
from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
from .rigid import RigidContacts, RigidContactsParams
from .soft import SoftContacts, SoftContactsParams
# Union of the concrete contact-parameter classes imported above, usable in
# type annotations that accept the parameters of any of these contact models.
ContactParamsTypes = (
    SoftContactsParams | RigidContactsParams | RelaxedRigidContactsParams
)
================================================
FILE: src/jaxsim/rbda/contacts/common.py
================================================
from __future__ import annotations
import abc
import functools
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.terrain
import jaxsim.typing as jtp
from jaxsim.math import STANDARD_GRAVITY
from jaxsim.utils import JaxsimDataclass
try:
from typing import Self
except ImportError:
from typing_extensions import Self
# Upper bounds used with `jnp.clip` on the stiffness and damping values that
# `ContactsParams.build_default_from_jaxsim_model` computes by default.
MAX_STIFFNESS = 1e6
MAX_DAMPING = 1e4
@functools.partial(jax.jit, static_argnames=("terrain",))
def compute_penetration_data(
    p: jtp.VectorLike,
    v: jtp.VectorLike,
    terrain: jaxsim.terrain.Terrain,
) -> tuple[jtp.Float, jtp.Float, jtp.Vector]:
    """
    Compute depth, rate, and terrain normal of a penetrating collidable point.

    Args:
        p: The position of the collidable point.
        v:
            The linear velocity of the point (linear component of the mixed 6D
            velocity of the implicit frame `C = (W_p_C, [W])` associated to
            the point).
        terrain: The considered terrain. It is a static argument of the jit.

    Returns:
        A tuple containing the penetration depth, the penetration velocity,
        and the considered terrain normal.
    """

    # Flatten the inputs and split the position into its coordinates.
    point_velocity = jnp.array(v).squeeze()
    x, y, z = jnp.array(p).squeeze()

    # Query the terrain below the point.
    normal = terrain.normal(x=x, y=y).squeeze()
    vertical_gap = terrain.height(x=x, y=y) - z

    # Project the vertical gap on the normal; negative values (point above
    # the terrain) are clamped to zero.
    gap_vector = jnp.array([0, 0, vertical_gap])
    depth = jnp.maximum(0.0, jnp.dot(gap_vector, normal))

    # The penetration rate is the velocity against the normal direction,
    # forced to zero whenever the point is not penetrating.
    depth_rate = jnp.where(depth > 0, -jnp.dot(point_velocity, normal), 0.0)

    return depth, depth_rate, normal
class ContactsParams(JaxsimDataclass):
    """
    Abstract base class of the parameters of a contact model.

    Note:
        Only the tunable parameters of a contact model belong here, i.e. those
        that can be changed during runtime. Static parameters, if any, should
        be stored in the corresponding `ContactModel` class.
    """

    @classmethod
    @abc.abstractmethod
    def build(cls: type[Self], **kwargs) -> Self:
        """
        Build a `ContactsParams` instance from the given parameters.

        Returns:
            The `ContactsParams` instance.
        """

    def build_default_from_jaxsim_model(
        self: type[Self],
        model: js.model.JaxSimModel,
        *,
        stiffness: jtp.FloatLike | None = None,
        damping: jtp.FloatLike | None = None,
        standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,
        static_friction_coefficient: jtp.FloatLike = 0.5,
        max_penetration: jtp.FloatLike = 0.001,
        number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
        damping_ratio: jtp.FloatLike = 1.0,
        p: jtp.FloatLike = 0.5,
        q: jtp.FloatLike = 0.5,
        **kwargs,
    ) -> Self:
        """
        Build a `ContactsParams` instance with defaults derived from a model.

        Args:
            model: The robot model considered by the contact model.
            stiffness:
                The stiffness of the contact model. When omitted, it is
                computed from the model mass and clipped to `MAX_STIFFNESS`.
            damping:
                The damping of the contact model. When omitted, it is computed
                from the damping ratio and clipped to `MAX_DAMPING`.
            standard_gravity: The standard gravity acceleration.
            static_friction_coefficient: The static friction coefficient.
            max_penetration: The maximum penetration depth.
            number_of_active_collidable_points_steady_state:
                The number of active collidable points in steady state.
            damping_ratio: The damping ratio.
            p: The first parameter of the contact model.
            q: The second parameter of the contact model.
            **kwargs: Optional additional arguments forwarded to `build`.

        Returns:
            The `ContactsParams` instance.

        Note:
            In the Soft Contacts model `stiffness` and `damping` are the
            terrain stiffness and damping, while in the Rigid Contacts model
            they are the Baumgarte stabilization stiffness and damping.

            The `damping_ratio` parameter selects the following regimes:

            - ξ > 1.0: over-damped
            - ξ = 1.0: critically damped
            - ξ < 1.0: under-damped
        """

        # Total mass of the model, summed over all the links.
        total_mass = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum()

        # Default stiffness: the one producing the desired steady-state
        # penetration under the average support force of a collidable point.
        # It depends on the non-linear exponent `p` of the Hunt/Crossley model.
        if stiffness is None:
            average_support_force = (
                total_mass
                * standard_gravity
                / number_of_active_collidable_points_steady_state
            )
            unclipped_stiffness = average_support_force / jnp.power(
                max_penetration, 1 + p
            )
            stiffness = jnp.clip(unclipped_stiffness, 0, MAX_STIFFNESS)

        # Default damping: a multiple of the critical damping, given by the
        # damping ratio.
        critical_damping = 2 * jnp.sqrt(stiffness * total_mass)

        if damping is None:
            damping = jnp.clip(damping_ratio * critical_damping, 0, MAX_DAMPING)

        return self.build(
            K=stiffness,
            D=damping,
            mu=static_friction_coefficient,
            p=p,
            q=q,
            **kwargs,
        )

    @abc.abstractmethod
    def valid(self, **kwargs) -> jtp.BoolLike:
        """
        Check the validity of the parameters.

        Returns:
            True if the parameters are valid, False otherwise.
        """
class ContactModel(JaxsimDataclass):
    """
    Abstract base class of all the contact models.
    """

    @classmethod
    @abc.abstractmethod
    def build(
        cls: type[Self],
        **kwargs,
    ) -> Self:
        """
        Build a `ContactModel` instance from the given parameters.

        Returns:
            The `ContactModel` instance.
        """

    @abc.abstractmethod
    def compute_contact_forces(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        **kwargs,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the contact forces.

        Args:
            model: The robot model considered by the contact model.
            data: The data of the considered model.
            **kwargs: Optional additional arguments, specific to the contact model.

        Returns:
            A tuple whose first element is the computed 6D contact force
            applied to the contact points and expressed in the world frame,
            and whose second element is a dictionary of optional additional
            information.
        """

    @classmethod
    def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
        """
        Build the zero state variables of the contact model.

        Args:
            model: The robot model considered by the contact model.

        Note:
            Some contact models need to extend the state vector of the
            integrated ODE system with additional variables. The integrators
            can operate on a generic state, as long as it is a PyTree.
            This method returns those additional variables, zeroed, as a
            dictionary of JAX arrays. The base implementation has none.

        Returns:
            A dictionary storing the zero state variables of the contact model.
        """

        return dict()

    @property
    def _parameters_class(self) -> type[ContactsParams]:
        """
        Return the class of the contact parameters.

        The class is looked up in `jaxsim.rbda.contacts` under the name of
        the contact model followed by the `Params` suffix.

        Returns:
            The class of the contact parameters.
        """

        import importlib

        contacts_module = importlib.import_module("jaxsim.rbda.contacts")

        # Support being evaluated on both a class and an instance.
        model_class = self if isinstance(self, type) else self.__class__

        return getattr(contacts_module, model_class.__name__ + "Params")

    @abc.abstractmethod
    def update_contact_state(
        self: type[Self], old_contact_state: dict[str, jtp.Array]
    ) -> dict[str, jtp.Array]:
        """
        Update the contact state.

        Args:
            old_contact_state: The old contact state.

        Returns:
            The updated contact state.
        """

    @abc.abstractmethod
    def update_velocity_after_impact(
        self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData
    ) -> js.data.JaxSimModelData:
        """
        Update the velocity after an impact.

        Args:
            model: The robot model considered by the contact model.
            data: The data of the considered model.

        Returns:
            The updated data of the considered model.
        """
================================================
FILE: src/jaxsim/rbda/contacts/relaxed_rigid.py
================================================
from __future__ import annotations
import dataclasses
from collections.abc import Callable
from typing import Any
import jax
import jax.numpy as jnp
import jax_dataclasses
import optax
from optax.tree_utils import tree_get
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
from . import common, soft
try:
from typing import Self
except ImportError:
from typing_extensions import Self
try:
from optax.tree_utils import tree_norm
except ImportError:
from optax.tree_utils import tree_l2_norm as tree_norm
def _scalar_field(value: float):
    """Declare a dataclass field whose default is a fresh scalar float array."""

    return dataclasses.field(default_factory=lambda: jnp.array(value, dtype=float))


@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsParams(common.ContactsParams):
    """Parameters of the relaxed rigid contacts model."""

    # Time constant
    time_constant: jtp.Float = _scalar_field(0.02)

    # Adimensional damping coefficient
    damping_coefficient: jtp.Float = _scalar_field(1.0)

    # Minimum impedance
    d_min: jtp.Float = _scalar_field(0.9)

    # Maximum impedance
    d_max: jtp.Float = _scalar_field(0.95)

    # Width
    width: jtp.Float = _scalar_field(0.001)

    # Midpoint
    midpoint: jtp.Float = _scalar_field(0.5)

    # Power exponent
    power: jtp.Float = _scalar_field(2.0)

    # Stiffness
    K: jtp.Float = _scalar_field(0.0)

    # Damping
    D: jtp.Float = _scalar_field(0.0)

    # Friction coefficient
    mu: jtp.Float = _scalar_field(0.005)

    def __hash__(self) -> int:
        """Hash the parameters from the wrapped content of all the fields."""

        from jaxsim.utils.wrappers import HashedNumpyArray

        # The order of the names defines the order of the hashed tuple.
        hashed_names = (
            "time_constant",
            "damping_coefficient",
            "d_min",
            "d_max",
            "width",
            "midpoint",
            "power",
            "K",
            "D",
            "mu",
        )

        return hash(
            tuple(HashedNumpyArray(getattr(self, name)) for name in hashed_names)
        )

    def __eq__(self, other: RelaxedRigidContactsParams) -> bool:
        """Compare two sets of parameters through their hashes."""

        if isinstance(other, RelaxedRigidContactsParams):
            return hash(self) == hash(other)

        return False

    @classmethod
    def build(
        cls: type[Self],
        *,
        time_constant: jtp.FloatLike | None = None,
        damping_coefficient: jtp.FloatLike | None = None,
        d_min: jtp.FloatLike | None = None,
        d_max: jtp.FloatLike | None = None,
        width: jtp.FloatLike | None = None,
        midpoint: jtp.FloatLike | None = None,
        power: jtp.FloatLike | None = None,
        K: jtp.FloatLike | None = None,
        D: jtp.FloatLike | None = None,
        mu: jtp.FloatLike | None = None,
        **kwargs,
    ) -> Self:
        """
        Create a `RelaxedRigidContactsParams` instance.

        Every argument left to None falls back to the default of the matching
        dataclass field. Additional keyword arguments are accepted and ignored.
        """

        requested = {
            "time_constant": time_constant,
            "damping_coefficient": damping_coefficient,
            "d_min": d_min,
            "d_max": d_max,
            "width": width,
            "midpoint": midpoint,
            "power": power,
            "K": K,
            "D": D,
            "mu": mu,
        }

        fields = cls.__dataclass_fields__

        # Convert each value to a float array, building the field default only
        # for the arguments that were not provided.
        return cls(
            **{
                name: jnp.array(
                    value if value is not None else fields[name].default_factory(),
                    dtype=float,
                )
                for name, value in requested.items()
            }
        )

    def valid(self) -> jtp.BoolLike:
        """Check if the parameters are valid."""

        conditions = (
            self.time_constant >= 0.0,
            self.damping_coefficient > 0.0,
            self.d_min >= 0.0,
            self.d_max <= 1.0,
            self.d_min <= self.d_max,
            self.width >= 0.0,
            self.midpoint >= 0.0,
            self.power >= 0.0,
            self.mu >= 0.0,
        )

        return all(bool(jnp.all(condition)) for condition in conditions)
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContacts(common.ContactModel):
"""Relaxed rigid contacts model."""
_solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
default=("tol", "maxiter", "memory_size", "scale_init_precond"), kw_only=True
)
_solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
default=(1e-6, 50, 10, False), kw_only=True
)
@property
def solver_options(self) -> dict[str, Any]:
    """
    Return the solver options as a dictionary.

    The dictionary pairs the statically stored option names with their
    values. A length mismatch between the two tuples raises an error.
    """

    names = self._solver_options_keys
    values = self._solver_options_values

    return {name: value for name, value in zip(names, values, strict=True)}
@classmethod
def build(
    cls: type[Self],
    solver_options: dict[str, Any] | None = None,
    **kwargs,
) -> Self:
    """
    Build a `RelaxedRigidContacts` instance.

    Args:
        solver_options:
            The options to pass to the L-BFGS solver. They are merged on top
            of the default options of the class.
        **kwargs: The parameters of the relaxed rigid contacts model.

    Returns:
        The `RelaxedRigidContacts` instance.

    Raises:
        ValueError: If any value of the merged solver options is not hashable.
    """

    # Start from the defaults stored in the class.
    defaults = dict(
        zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
    )

    # Let the user-provided options override the defaults.
    overrides = {} if solver_options is None else solver_options
    merged_options = defaults | overrides

    option_names = tuple(merged_options.keys())
    option_values = tuple(merged_options.values())

    # The options are stored in static fields, hence they must be hashable.
    try:
        hash(option_values)
    except TypeError as exc:
        raise ValueError(
            "The values of the solver options must be hashable."
        ) from exc

    return cls(
        _solver_options_keys=option_names,
        _solver_options_values=option_values,
        **kwargs,
    )
def update_contact_state(
    self: type[Self], old_contact_state: dict[str, jtp.Array]
) -> dict[str, jtp.Array]:
    """
    Update the contact state.

    This implementation ignores the old state and always returns an empty
    dictionary.

    Args:
        old_contact_state: The old contact state.

    Returns:
        The updated contact state.
    """

    return dict()
def update_velocity_after_impact(
    self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> js.data.JaxSimModelData:
    """
    Update the velocity after an impact.

    This implementation applies no correction: the input data is returned
    unchanged and the model is not used.

    Args:
        model: The robot model considered by the contact model.
        data: The data of the considered model.

    Returns:
        The updated data of the considered model.
    """

    # No impact handling here: hand back the very same data object.
    return data
@jax.jit
def compute_contact_forces(
self,
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
"""
Compute the contact forces.
Args:
model: The model to consider.
data: The data of the considered model.
link_forces:
Optional `(n_links, 6)` matrix of external forces acting on the links,
expressed in the same representation of data.
joint_force_references:
Optional `(n_joints,)` vector of joint forces.
Returns:
A tuple containing as first element the computed contact forces in inertial representation.
"""
link_forces = jnp.atleast_2d(
jnp.array(link_forces, dtype=float).squeeze()
if link_forces is not None
else jnp.zeros((model.number_of_links(), 6))
)
joint_force_references = jnp.atleast_1d(
jnp.array(joint_force_references, dtype=float).squeeze()
if joint_force_references is not None
else jnp.zeros(model.number_of_joints())
)
references = js.references.JaxSimModelReferences.build(
model=model,
data=data,
velocity_representation=data.velocity_representation,
link_forces=link_forces,
joint_force_references=joint_force_references,
)
# Compute the position and linear velocities (mixed representation) of
# all collidable points belonging to the robot.
position, velocity = js.contact.collidable_point_kinematics(
model=model, data=data
)
# Compute the penetration depth and velocity of the collidable points.
# Note that this function considers the penetration in the normal direction.
δ, _, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(
position, velocity, model.terrain
)
# Compute the position in the constraint frame.
position_constraint = jax.vmap(lambda δ, n̂: -δ * n̂)(δ, n̂)
# Compute the regularization terms.
a_ref, r, *_ = self._regularizers(
model=model,
position_constraint=position_constraint,
velocity_constraint=velocity,
parameters=model.contact_params,
)
# Compute the transforms of the implicit frames corresponding to the
# collidable points.
W_H_C = js.contact.transforms(model=model, data=data)
with (
data.switch_velocity_representation(VelRepr.Mixed),
references.switch_velocity_representation(VelRepr.Mixed),
):
BW_ν = data.generalized_velocity
BW_ν̇_free = jnp.hstack(
js.model.forward_dynamics_aba(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
joint_forces=references.joint_force_references(model=model),
)
)
M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data)
# Compute the linear part of the Jacobian of the collidable points
Jl_WC = jnp.vstack(
jax.vmap(lambda J, δ: J * (δ > 0))(
js.contact.jacobian(model=model, data=data)[:, :3, :], δ
)
)
# Compute the linear part of the Jacobian derivative of the collidable points
J̇l_WC = jnp.vstack(
jax.vmap(lambda J̇, δ: J̇ * (δ > 0))(
js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
),
)
# Compute the Delassus matrix directly using J and J̇.
G_contacts = Jl_WC @ M_inv @ Jl_WC.T
# Compute the free mixed linear acceleration of the collidable points.
CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇l_WC @ BW_ν
# Calculate quantities for the linear optimization problem.
R = jnp.diag(r)
A = G_contacts + R
b = CW_al_free_WC - a_ref
# Create the objective function to minimize as a lambda computing the cost
# from the optimized variables x.
objective = lambda x, A, b: jnp.sum(jnp.square(A @ x + b))
# ========================================
# Helper function to run the L-BFGS solver
# ========================================
def run_optimization(
init_params: jtp.Vector,
fun: Callable,
opt: optax.GradientTransformationExtraArgs,
maxiter: int,
tol: float,
) -> tuple[jtp.Vector, optax.OptState]:
# Get the function to compute the loss and the gradient w.r.t. its inputs.
value_and_grad_fn = optax.value_and_grad_from_state(fun)
# Initialize the carry of the following loop.
OptimizationCarry = tuple[jtp.Vector, optax.OptState]
init_carry: OptimizationCarry = (init_params, opt.init(params=init_params))
def step(carry: OptimizationCarry) -> OptimizationCarry:
params, state = carry
value, grad = value_and_grad_fn(
params,
state=state,
A=A,
b=b,
)
updates, state = opt.update(
updates=grad,
state=state,
params=params,
value=value,
grad=grad,
value_fn=fun,
A=A,
b=b,
)
params = optax.apply_updates(params, updates)
return params, state
# TODO: maybe fix the number of iterations and switch to scan?
def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:
_, state = carry
iter_num = tree_get(state, "count")
grad = tree_get(state, "grad")
err = tree_norm(grad)
return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol))
final_params, final_state = jax.lax.while_loop(
continuing_criterion, step, init_carry
)
return final_params, final_state
# ======================================
# Compute the contact forces with L-BFGS
# ======================================
# Initialize the optimized forces with a linear Hunt/Crossley model.
init_params = jax.vmap(
lambda p, v: soft.SoftContacts.hunt_crossley_contact_model(
position=p,
velocity=v,
terrain=model.terrain,
K=1e6,
D=2e3,
p=0.5,
q=0.5,
# No tangential initial forces.
mu=0.0,
tangential_deformation=jnp.zeros(3),
)[0]
)(position, velocity).flatten()
# Get the solver options.
solver_options = self.solver_options
# Extract the options corresponding to the convergence criteria.
# All the remaining options are passed to the solver.
tol = solver_options.pop("tol")
maxiter = solver_options.pop("maxiter")
solve_fn = lambda *_: run_optimization(
init_params=init_params,
fun=objective,
opt=optax.lbfgs(**solver_options),
tol=tol,
maxiter=maxiter,
)
# Compute the 3D linear force in C[W] frame.
solution, _ = jax.lax.custom_linear_solve(
lambda x: A @ x,
-b,
solve=solve_fn,
symmetric=True,
has_aux=True,
)
# Reshape the optimized solution to be a matrix of 3D contact forces.
CW_fl_C = solution.reshape(-1, 3)
# Convert the contact forces from mixed to inertial-fixed representation.
W_f_C = jax.vmap(
lambda CW_fl_C, W_H_C: (
ModelDataWithVelocityRepresentation.other_representation_to_inertial(
array=jnp.zeros(6).at[0:3].set(CW_fl_C),
transform=W_H_C,
other_representation=VelRepr.Mixed,
is_force=True,
)
),
)(CW_fl_C, W_H_C)
return W_f_C, {}
@staticmethod
def _regularizers(
    model: js.model.JaxSimModel,
    position_constraint: jtp.Vector,
    velocity_constraint: jtp.Vector,
    parameters: RelaxedRigidContactsParams,
) -> tuple:
    """
    Evaluate the per-point regularization quantities of the contact problem.

    Args:
        model: The jaxsim model.
        position_constraint: The position of the collidable points in the constraint frame.
        velocity_constraint: The velocity of the collidable points in the constraint frame.
        parameters: The parameters of the relaxed rigid contacts model.

    Returns:
        A tuple `(a_ref, R, K, D)` with the reference acceleration, the
        regularization term, the stiffness, and the damping. Each entry is
        concatenated over the enabled collidable points, and it is zeroed for
        the points whose constraint position is exactly zero.
    """

    # Read the parameters of the contact model.
    Ω = parameters.time_constant
    ζ = parameters.damping_coefficient
    ξ_min = parameters.d_min
    ξ_max = parameters.d_max
    width = parameters.width
    mid = parameters.midpoint
    p = parameters.power
    # NOTE(review): the user-provided K and D are read here but never used, since
    # the gains are recomputed from the other parameters for each point below.
    _user_K = parameters.K
    _user_D = parameters.D
    μ = parameters.mu

    # Get the parent link of each enabled collidable point.
    contact_parameters = model.kin_dyn_parameters.contact_parameters
    enabled_point_idxs = contact_parameters.indices_of_enabled_collidable_points
    parent_link_idxs = jnp.array(contact_parameters.body, dtype=int)[
        enabled_point_idxs
    ]

    # Compute the 6D inertia matrices of all links.
    M_L = js.model.link_spatial_inertia_matrices(model=model)

    def point_terms(
        link_idx: jtp.Int, pos: jtp.Vector, vel: jtp.Vector
    ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector]:
        """Compute `(a_ref, R, K, D)` of a single collidable point."""

        # Evaluate the piecewise power profile of the normalized position.
        x = jnp.abs(pos) / width
        below_mid = (1.0 / jnp.power(mid, p - 1)) * jnp.power(x, p)
        above_mid = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - x, p)
        y = jnp.where(x < mid, below_mid, above_mid)

        # Compute the impedance, bounded in [ξ_min, ξ_max] and saturated to
        # ξ_max beyond the width.
        ξ = jnp.clip(ξ_min + y * (ξ_max - ξ_min), ξ_min, ξ_max)
        ξ = jnp.where(x > 1.0, ξ_max, ξ)

        # Compute the spring and damper parameters during runtime from the
        # impedance and other contact parameters.
        K = 1 / (ξ_max * Ω * ζ) ** 2
        D = 2 / (ξ_max * Ω)

        # Negative gains are mapped to a form closer to a classic Baumgarte
        # regularization.
        K = jnp.where(K < 0, -K / ξ_max**2, K)
        D = jnp.where(D < 0, -D / ξ_max, D)

        # Compute the reference acceleration.
        a_ref = -(D * vel + K * ξ * pos)

        # Compute the regularization term from the impedance, the friction, and
        # the inverse of the top-left 3x3 block of the parent link inertia.
        scale = (2 * μ**2 * (1 - ξ) / (ξ + 1e-12)) * (1 + μ**2)
        R = scale @ jnp.linalg.inv(M_L[link_idx, :3, :3])

        # Zero all the returned terms in case of no contact.
        is_active = (pos.dot(pos) > 0).astype(float)

        return tuple(jnp.atleast_1d(term) * is_active for term in (a_ref, R, K, D))

    # Evaluate the terms of all points and flatten each of them.
    per_point_terms = jax.vmap(point_terms)(
        parent_link_idxs, position_constraint, velocity_constraint
    )
    a_ref, R, K, D = (jnp.concatenate(term) for term in per_point_terms)

    return a_ref, R, K, D
================================================
FILE: src/jaxsim/rbda/contacts/rigid.py
================================================
from __future__ import annotations
import dataclasses
from typing import Any
import jax
import jax.numpy as jnp
import jax_dataclasses
import qpax
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
from . import common
from .common import ContactModel, ContactsParams
try:
from typing import Self
except ImportError:
from typing_extensions import Self
@jax_dataclasses.pytree_dataclass
class RigidContactsParams(ContactsParams):
    """Parameters of the rigid contacts model."""

    # Static friction coefficient
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Baumgarte proportional term
    K: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Baumgarte derivative term
    D: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    def __hash__(self) -> int:
        """Hash the parameters from the hashes of their array values."""

        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray.hash_of_array(self.mu),
                HashedNumpyArray.hash_of_array(self.K),
                HashedNumpyArray.hash_of_array(self.D),
            )
        )

    def __eq__(self, other: RigidContactsParams) -> bool:
        """Compare two sets of parameters through their hashes."""

        if not isinstance(other, RigidContactsParams):
            return False

        return hash(self) == hash(other)

    @classmethod
    def build(
        cls: type[Self],
        *,
        mu: jtp.FloatLike | None = None,
        K: jtp.FloatLike | None = None,
        D: jtp.FloatLike | None = None,
        **kwargs,
    ) -> Self:
        """
        Create a `RigidContactsParams` instance.

        Args:
            mu: The static friction coefficient.
            K: The Baumgarte proportional term.
            D: The Baumgarte derivative term.
            **kwargs: Extra arguments, which are accepted and not used.

        Returns:
            The `RigidContactsParams` instance. Arguments left to `None` take
            the default value of the corresponding field.
        """

        return cls(
            mu=jnp.array(
                mu
                if mu is not None
                else cls.__dataclass_fields__["mu"].default_factory()
            ).astype(float),
            K=jnp.array(
                K if K is not None else cls.__dataclass_fields__["K"].default_factory()
            ).astype(float),
            D=jnp.array(
                D if D is not None else cls.__dataclass_fields__["D"].default_factory()
            ).astype(float),
        )

    def valid(self) -> jtp.BoolLike:
        """
        Check if the parameters are valid.

        Returns:
            `True` if all the parameters are non-negative, `False` otherwise.
        """

        # Combine the checks with array operations instead of Python's `bool` and
        # `and`, so that the result can also be computed on traced values.
        return (
            jnp.all(self.mu >= 0.0)
            & jnp.all(self.K >= 0.0)
            & jnp.all(self.D >= 0.0)
        )
@jax_dataclasses.pytree_dataclass
class RigidContacts(ContactModel):
    """
    Rigid contacts model.

    The contact forces are computed by solving a QP with `qpax`, and the
    generalized velocity can be updated after an impact.
    """

    # Regularization term added to the diagonal of the Delassus matrix.
    regularization_delassus: jax_dataclasses.Static[float] = dataclasses.field(
        default=1e-6, kw_only=True
    )

    # The solver options are stored as two parallel static tuples (keys and values)
    # and are exposed as a dictionary by the `solver_options` property.
    _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
        default=("solver_tol",), kw_only=True
    )
    _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
        default=(1e-3,), kw_only=True
    )

    @property
    def solver_options(self) -> dict[str, Any]:
        """Get the solver options as a dictionary."""

        return dict(
            zip(
                self._solver_options_keys,
                self._solver_options_values,
                strict=True,
            )
        )

    @classmethod
    def build(
        cls: type[Self],
        regularization_delassus: jtp.FloatLike | None = None,
        solver_options: dict[str, Any] | None = None,
        **kwargs,
    ) -> Self:
        """
        Create a `RigidContacts` instance with specified parameters.

        Args:
            regularization_delassus:
                The regularization term to add to the diagonal of the Delassus matrix.
            solver_options: The options to pass to the QP solver.
            **kwargs: Extra arguments which are ignored.

        Returns:
            The `RigidContacts` instance.

        Raises:
            ValueError: If the values of the solver options are not hashable.
        """

        if kwargs:
            logging.warning(msg=f"Ignoring extra arguments: {kwargs}")

        # Get the default solver options.
        default_solver_options = dict(
            zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
        )

        # Create the solver options to set by combining the default solver options
        # with the user-provided solver options.
        solver_options = default_solver_options | (
            solver_options if solver_options is not None else {}
        )

        # Make sure that the solver options are hashable.
        # We need to check this because the solver options are static.
        try:
            hash(tuple(solver_options.values()))
        except TypeError as exc:
            raise ValueError(
                "The values of the solver options must be hashable."
            ) from exc

        # The extra arguments are intentionally not forwarded to the constructor:
        # they are ignored as documented, instead of making the constructor fail.
        return cls(
            regularization_delassus=float(
                regularization_delassus
                if regularization_delassus is not None
                else cls.__dataclass_fields__["regularization_delassus"].default
            ),
            _solver_options_keys=tuple(solver_options.keys()),
            _solver_options_values=tuple(solver_options.values()),
        )

    @staticmethod
    def compute_impact_velocity(
        inactive_collidable_points: jtp.ArrayLike,
        M: jtp.MatrixLike,
        J_WC: jtp.MatrixLike,
        generalized_velocity: jtp.VectorLike,
    ) -> jtp.Vector:
        """
        Return the new velocity of the system after a potential impact.

        Args:
            inactive_collidable_points: The activation state of the collidable points.
            M: The mass matrix of the system (in mixed representation).
            J_WC: The Jacobian matrix of the collidable points (in mixed representation).
            generalized_velocity: The generalized velocity of the system.

        Returns:
            The generalized velocity of the system after the impact.

        Note:
            The mass matrix `M`, the Jacobian `J_WC`, and the generalized velocity `generalized_velocity`
            must be expressed in the same velocity representation.
        """

        # Compute system velocity after impact maintaining zero linear velocity of active points.
        sl = jnp.s_[:, 0:3, :]
        Jl_WC = J_WC[sl]

        # Zero out the jacobian rows of inactive points.
        Jl_WC = jnp.vstack(
            jnp.where(
                inactive_collidable_points[:, jnp.newaxis, jnp.newaxis],
                jnp.zeros_like(Jl_WC),
                Jl_WC,
            )
        )

        # Assemble the linear system whose unknowns are the post-impact velocity
        # followed by one extra variable per row of the linear Jacobian.
        A = jnp.vstack(
            [
                jnp.hstack([M, -Jl_WC.T]),
                jnp.hstack([Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]),
            ]
        )
        b = jnp.hstack([M @ generalized_velocity, jnp.zeros(Jl_WC.shape[0])])

        # Solve in the least-squares sense and keep only the velocity part.
        BW_ν_post_impact = jnp.linalg.lstsq(A, b)[0]

        return BW_ν_post_impact[0 : M.shape[0]]

    @jax.jit
    @js.common.named_scope
    def compute_contact_forces(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        *,
        link_forces: jtp.MatrixLike | None = None,
        joint_force_references: jtp.VectorLike | None = None,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the contact forces.

        Args:
            model: The model to consider.
            data: The data of the considered model.
            link_forces:
                Optional `(n_links, 6)` matrix of external forces acting on the links,
                expressed in the same representation of data.
            joint_force_references:
                Optional `(n_joints,)` vector of joint forces.

        Returns:
            A tuple containing as first element the computed contact forces, and
            as second element an empty dictionary.
        """

        # Get the indices of the enabled collidable points.
        indices_of_enabled_collidable_points = (
            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
        )

        n_collidable_points = len(indices_of_enabled_collidable_points)

        link_forces = jnp.atleast_2d(
            jnp.array(link_forces, dtype=float).squeeze()
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        )

        joint_force_references = jnp.atleast_1d(
            jnp.array(joint_force_references, dtype=float).squeeze()
            if joint_force_references is not None
            else jnp.zeros((model.number_of_joints(),))
        )

        # Build a references object to simplify converting link forces.
        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            velocity_representation=data.velocity_representation,
            link_forces=link_forces,
            joint_force_references=joint_force_references,
        )

        # Compute the position and linear velocities (mixed representation) of
        # all enabled collidable points belonging to the robot.
        position, velocity = js.contact.collidable_point_kinematics(
            model=model, data=data
        )

        # Compute the penetration depth and velocity of the collidable points.
        # Note that this function considers the penetration in the normal direction.
        δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(
            position, velocity, model.terrain
        )

        W_H_C = js.contact.transforms(model=model, data=data)

        with (
            references.switch_velocity_representation(VelRepr.Mixed),
            data.switch_velocity_representation(VelRepr.Mixed),
        ):
            # Compute kin-dyn quantities used in the contact model.
            BW_ν = data.generalized_velocity

            M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data)

            J_WC = js.contact.jacobian(model=model, data=data)
            J̇_WC = js.contact.jacobian_derivative(model=model, data=data)

            # Compute the generalized free acceleration.
            BW_ν̇_free = jnp.hstack(
                js.model.forward_dynamics_aba(
                    model=model,
                    data=data,
                    link_forces=references.link_forces(model=model, data=data),
                    joint_forces=references.joint_force_references(model=model),
                )
            )

        # Compute the free linear acceleration of the collidable points.
        # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C.
        free_contact_acc = _linear_acceleration_of_collidable_points(
            BW_nu=BW_ν,
            BW_nu_dot=BW_ν̇_free,
            CW_J_WC_BW=J_WC,
            CW_J_dot_WC_BW=J̇_WC,
        ).flatten()

        # Compute stabilization term.
        baumgarte_term = _compute_baumgarte_stabilization_term(
            inactive_collidable_points=(δ <= 0),
            δ=δ,
            δ_dot=δ_dot,
            n=n̂,
            K=model.contact_params.K,
            D=model.contact_params.D,
        ).flatten()

        # Compute the Delassus matrix.
        delassus_matrix = _delassus_matrix(M_inv=M_inv, J_WC=J_WC)

        # Initialize regularization term of the Delassus matrix for
        # better numerical conditioning.
        Iε = self.regularization_delassus * jnp.eye(delassus_matrix.shape[0])

        # Construct the quadratic cost function.
        Q = delassus_matrix + Iε
        q = free_contact_acc - baumgarte_term

        # Construct the inequality constraints.
        G = _compute_ineq_constraint_matrix(
            inactive_collidable_points=(δ <= 0), mu=model.contact_params.mu
        )
        h_bounds = jnp.zeros(shape=(n_collidable_points * 6,))

        # Construct the equality constraints.
        A = jnp.zeros((0, 3 * n_collidable_points))
        b = jnp.zeros((0,))

        # Solve the following optimization problem with qpax:
        #
        # min_{x} 0.5 x⊤ Q x + q⊤ x
        #
        # s.t. A x = b
        #      G x ≤ h
        #
        # TODO: add possibility to notify if the QP problem did not converge.
        solution, _, _, _, converged, _ = qpax.solve_qp(  # noqa: RUF059
            Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options
        )

        # Reshape the optimized solution to be a matrix of 3D contact forces.
        CW_fl_C = solution.reshape(-1, 3)

        # Convert the contact forces from mixed to inertial-fixed representation.
        W_f_C = jax.vmap(
            lambda CW_fl_C, W_H_C: (
                ModelDataWithVelocityRepresentation.other_representation_to_inertial(
                    array=jnp.zeros(6).at[0:3].set(CW_fl_C),
                    transform=W_H_C,
                    other_representation=VelRepr.Mixed,
                    is_force=True,
                )
            ),
        )(CW_fl_C, W_H_C)

        return W_f_C, {}

    @jax.jit
    @js.common.named_scope
    def update_velocity_after_impact(
        self, model: js.model.JaxSimModel, data: js.data.JaxSimModelData
    ) -> js.data.JaxSimModelData:
        """
        Update the velocity after an impact.

        Args:
            model: The robot model considered by the contact model.
            data: The data of the considered model.

        Returns:
            The updated data of the considered model.
        """

        # Extract the indices corresponding to the enabled collidable points.
        indices_of_enabled_collidable_points = (
            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
        )

        W_p_C = js.contact.collidable_point_positions(model, data)[
            indices_of_enabled_collidable_points
        ]

        # Compute the penetration depth of the collidable points.
        δ, *_ = jax.vmap(
            common.compute_penetration_data,
            in_axes=(0, 0, None),
        )(W_p_C, jnp.zeros_like(W_p_C), model.terrain)

        with data.switch_velocity_representation(VelRepr.Mixed):
            # NOTE(review): `compute_contact_forces` uses the output of
            # `js.contact.jacobian` without indexing the enabled points, while
            # here it is indexed; confirm that both usages are intended.
            J_WC = js.contact.jacobian(model, data)[
                indices_of_enabled_collidable_points
            ]
            M = js.model.free_floating_mass_matrix(model, data)
            BW_ν_pre_impact = data.generalized_velocity

            # Compute the impact velocity.
            # It may be discontinuous in case new contacts are made.
            BW_ν_post_impact = RigidContacts.compute_impact_velocity(
                generalized_velocity=BW_ν_pre_impact,
                inactive_collidable_points=(δ <= 0),
                M=M,
                J_WC=J_WC,
            )

        # Convert the base velocity from mixed to inertial-fixed representation,
        # using the base transform with its rotation replaced by the identity.
        BW_ν_post_impact_inertial = data.other_representation_to_inertial(
            array=BW_ν_post_impact[0:6],
            other_representation=VelRepr.Mixed,
            transform=data._base_transform.at[0:3, 0:3].set(jnp.eye(3)),
            is_force=False,
        )

        # Reset the generalized velocity.
        data = dataclasses.replace(
            data,
            _base_linear_velocity=BW_ν_post_impact_inertial[0:3],
            _base_angular_velocity=BW_ν_post_impact_inertial[3:6],
            _joint_velocities=BW_ν_post_impact[6:],
        )

        return data

    def update_contact_state(
        self, old_contact_state: dict[str, jtp.Array]
    ) -> dict[str, jtp.Array]:
        """
        Update the contact state.

        Args:
            old_contact_state: The old contact state.

        Returns:
            The updated contact state, which is always empty for this model.
        """

        return {}
def _delassus_matrix(
    M_inv: jtp.MatrixLike,
    J_WC: jtp.MatrixLike,
) -> jtp.Matrix:
    """
    Compute the Delassus matrix from the linear part of the contact Jacobians.

    Args:
        M_inv: The inverse of the mass matrix.
        J_WC:
            The stacked Jacobians of the collidable points, indexed as
            `(point, row, dof)`. Only the first three rows of each Jacobian
            (its linear part) are used.

    Returns:
        The matrix `J @ M_inv @ J.T`, where `J` vertically stacks the linear
        part of the Jacobian of each point.
    """

    # Keep the linear rows of each Jacobian and stack them vertically.
    sl = jnp.s_[:, 0:3, :]
    J_WC_lin = jnp.vstack(J_WC[sl])

    delassus_matrix = J_WC_lin @ M_inv @ J_WC_lin.T
    return delassus_matrix
@jax.jit
@js.common.named_scope
def _compute_ineq_constraint_matrix(
    inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike
) -> jtp.Matrix:
    """
    Assemble the inequality constraint matrix of all the collidable points.

    A 6x3 block is built for each collidable point, and the blocks are then
    arranged block-diagonally. The rows of each block are:

    - Rows 0-3: friction pyramid constraints.
    - Row 4: non-negativity of the vertical force.
    - Row 5: contact complementarity condition, whose only entry that may be
      non-zero is the vertical force coefficient, set to the inactive flag
      of the point.
    """

    n_points = len(inactive_collidable_points)

    # Constraint block shared by all the points.
    friction_pyramid = [[1, 0, -mu], [0, 1, -mu], [-1, 0, -mu], [0, -1, -mu]]
    point_block = jnp.array([*friction_pyramid, [0, 0, -1], [0, 0, 0]])

    # Replicate the block and fill the complementarity row of each point.
    blocks = jnp.tile(point_block, (n_points, 1, 1))
    blocks = blocks.at[:, 5, 2].set(inactive_collidable_points)

    return jax.scipy.linalg.block_diag(*blocks)
@jax.jit
@js.common.named_scope
def _linear_acceleration_of_collidable_points(
    BW_nu: jtp.ArrayLike,
    BW_nu_dot: jtp.ArrayLike,
    CW_J_WC_BW: jtp.MatrixLike,
    CW_J_dot_WC_BW: jtp.MatrixLike,
) -> jtp.Matrix:
    """
    Compute the linear acceleration of the collidable points.

    Args:
        BW_nu: The generalized velocity.
        BW_nu_dot: The generalized acceleration.
        CW_J_WC_BW: The stacked Jacobians of the collidable points.
        CW_J_dot_WC_BW: The stacked Jacobian derivatives of the collidable points.

    Returns:
        The first three components of the 6D acceleration `J̇ ν + J ν̇` of each
        collidable point, with singleton dimensions squeezed out.
    """

    # Velocity-dependent and acceleration-dependent contributions.
    velocity_term = jnp.vstack(CW_J_dot_WC_BW) @ BW_nu
    acceleration_term = jnp.vstack(CW_J_WC_BW) @ BW_nu_dot

    # Since we use doubly-mixed jacobians, the linear part corresponds to W_p̈_C.
    accelerations_6d = (velocity_term + acceleration_term).reshape(-1, 6)

    return accelerations_6d[:, 0:3].squeeze()
@jax.jit
@js.common.named_scope
def _compute_baumgarte_stabilization_term(
    inactive_collidable_points: jtp.ArrayLike,
    δ: jtp.ArrayLike,
    δ_dot: jtp.ArrayLike,
    n: jtp.ArrayLike,
    K: jtp.FloatLike,
    D: jtp.FloatLike,
) -> jtp.Array:
    """
    Compute the Baumgarte stabilization term of the collidable points.

    Args:
        inactive_collidable_points: The flags marking the inactive points.
        δ: The penetration depth of the points.
        δ_dot: The penetration rate of the points.
        n: The normal associated with each point.
        K: The proportional gain.
        D: The derivative gain.

    Returns:
        The term `(K δ + D δ_dot) n` for each point, replaced by zeros for the
        inactive points.
    """

    # Scale each normal by the proportional-derivative correction of its point.
    correction = (K * δ + D * δ_dot)[:, jnp.newaxis] * n

    # Mask out the correction of the inactive points.
    inactive_mask = inactive_collidable_points[:, jnp.newaxis]

    return jnp.where(inactive_mask, jnp.zeros_like(n), correction)
================================================
FILE: src/jaxsim/rbda/contacts/soft.py
================================================
from __future__ import annotations
import dataclasses
import functools
import jax
import jax.numpy as jnp
import jax_dataclasses
import jaxsim.api as js
import jaxsim.math
import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.terrain import Terrain
from . import common
try:
from typing import Self
except ImportError:
from typing_extensions import Self
@jax_dataclasses.pytree_dataclass
class SoftContactsParams(common.ContactsParams):
    """Parameters of the soft contacts model."""

    # Stiffness parameter
    K: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1e6, dtype=float)
    )

    # Damping parameter
    D: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(2000, dtype=float)
    )

    # Static friction coefficient
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Exponent p of the Hunt/Crossley model
    p: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Exponent q of the Hunt/Crossley model
    q: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    def __hash__(self) -> int:
        """Hash the parameters from the hashes of their array values."""

        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray.hash_of_array(self.K),
                HashedNumpyArray.hash_of_array(self.D),
                HashedNumpyArray.hash_of_array(self.mu),
                HashedNumpyArray.hash_of_array(self.p),
                HashedNumpyArray.hash_of_array(self.q),
            )
        )

    def __eq__(self, other: SoftContactsParams) -> bool:
        """Compare two sets of parameters through their hashes."""

        if not isinstance(other, SoftContactsParams):
            return False

        return hash(self) == hash(other)

    @classmethod
    def build(
        cls: type[Self],
        *,
        K: jtp.FloatLike = 1e6,
        D: jtp.FloatLike = 2_000,
        mu: jtp.FloatLike = 0.5,
        p: jtp.FloatLike = 0.5,
        q: jtp.FloatLike = 0.5,
        **kwargs,
    ) -> Self:
        """
        Create a SoftContactsParams instance with specified parameters.

        Args:
            K: The stiffness parameter.
            D: The damping parameter of the soft contacts model.
            mu: The static friction coefficient.
            p:
                The exponent p of the Hunt/Crossley model. In
                `SoftContacts.hunt_crossley_contact_model` it is the power of
                the penetration depth that scales the stiffness `K`.
            q:
                The exponent q of the Hunt/Crossley model. In
                `SoftContacts.hunt_crossley_contact_model` it is the power of
                the penetration depth that scales the damping `D`.
            **kwargs: Extra arguments, which are accepted and not used.

        Returns:
            An instance of `cls` with the specified parameters.
        """

        # Build through `cls` so that subclasses get an instance of their own type.
        return cls(
            K=jnp.array(K, dtype=float),
            D=jnp.array(D, dtype=float),
            mu=jnp.array(mu, dtype=float),
            p=jnp.array(p, dtype=float),
            q=jnp.array(q, dtype=float),
        )

    def valid(self) -> jtp.BoolLike:
        """
        Check if the parameters are valid.

        Returns:
            `True` if all the parameters are non-negative, `False` otherwise.
        """

        return jnp.hstack(
            [
                self.K >= 0.0,
                self.D >= 0.0,
                self.mu >= 0.0,
                self.p >= 0.0,
                self.q >= 0.0,
            ]
        ).all()
@jax_dataclasses.pytree_dataclass
class SoftContacts(common.ContactModel):
    """Soft contacts model."""

    @classmethod
    def build(
        cls: type[Self],
        **kwargs,
    ) -> Self:
        """
        Create a `SoftContacts` instance.

        Args:
            **kwargs: Extra arguments which are ignored.

        Returns:
            The `SoftContacts` instance.
        """

        if kwargs:
            logging.warning(msg=f"Ignoring extra arguments: {kwargs}")

        # The extra arguments are intentionally not forwarded to the constructor:
        # they are ignored as logged, instead of making the constructor fail.
        return cls()

    @classmethod
    def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
        """
        Build zero state variables of the contact model.

        Args:
            model: The model to consider.

        Returns:
            A dictionary with a zero `tangential_deformation` array having one
            3D row per entry of the contact parameters' `body`.
        """

        # Initialize the material deformation to zero.
        tangential_deformation = jnp.zeros(
            shape=(len(model.kin_dyn_parameters.contact_parameters.body), 3),
            dtype=float,
        )

        return {"tangential_deformation": tangential_deformation}

    def update_contact_state(
        self, old_contact_state: dict[str, jtp.Array]
    ) -> dict[str, jtp.Array]:
        """
        Update the contact state.

        Args:
            old_contact_state: The old contact state, containing the key `m_dot`.

        Returns:
            The updated contact state, storing the `m_dot` entry of the old
            contact state under the key `tangential_deformation`.
        """

        return {"tangential_deformation": old_contact_state["m_dot"]}

    def update_velocity_after_impact(
        self, model: js.model.JaxSimModel, data: js.data.JaxSimModelData
    ) -> js.data.JaxSimModelData:
        """
        Update the velocity after an impact.

        Args:
            model: The robot model considered by the contact model.
            data: The data of the considered model.

        Returns:
            The data of the considered model, returned unchanged.
        """

        return data

    @staticmethod
    @functools.partial(jax.jit, static_argnames=("terrain",))
    def hunt_crossley_contact_model(
        position: jtp.VectorLike,
        velocity: jtp.VectorLike,
        tangential_deformation: jtp.VectorLike,
        terrain: Terrain,
        K: jtp.FloatLike,
        D: jtp.FloatLike,
        mu: jtp.FloatLike,
        p: jtp.FloatLike = 0.5,
        q: jtp.FloatLike = 0.5,
    ) -> tuple[jtp.Vector, jtp.Vector]:
        """
        Compute the contact force using the Hunt/Crossley model.

        Args:
            position: The position of the collidable point.
            velocity: The velocity of the collidable point.
            tangential_deformation: The material deformation of the collidable point.
            terrain: The terrain model.
            K: The stiffness parameter.
            D: The damping parameter of the soft contacts model.
            mu: The static friction coefficient.
            p:
                The exponent of the penetration depth that scales the
                stiffness `K` (spring-related non-linearity).
            q:
                The exponent of the penetration depth that scales the
                damping `D` (damping-related non-linearity).

        Returns:
            A tuple containing the computed contact force and the derivative of the
            material deformation.
        """

        # Convert the input vectors to arrays.
        W_p_C = jnp.array(position, dtype=float).squeeze()
        W_ṗ_C = jnp.array(velocity, dtype=float).squeeze()
        m = jnp.array(tangential_deformation, dtype=float).squeeze()

        # Use symbol for the static friction.
        μ = mu

        # Compute the penetration depth, its rate, and the considered terrain normal.
        δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain)

        # There are few operations like computing the norm of a vector with zero length
        # or computing the square root of zero that are problematic in an AD context.
        # To avoid these issues, we introduce a small tolerance ε to their arguments
        # and make sure that we do not check them against zero directly.
        ε = jnp.finfo(float).eps

        # Compute the powers of the penetration depth.
        # Inject ε to address AD issues in differentiating the square root when
        # p and q are fractional.
        δp = jnp.power(δ + ε, p)
        δq = jnp.power(δ + ε, q)

        # ========================
        # Compute the normal force
        # ========================

        # Non-linear spring-damper model (Hunt/Crossley model).
        # This is the force magnitude along the direction normal to the terrain.
        force_normal_mag = (K * δp) * δ + (D * δq) * δ̇

        # Depending on the magnitude of δ̇, the normal force could be negative.
        force_normal_mag = jnp.maximum(0.0, force_normal_mag)

        # Compute the 3D linear force in C[W] frame.
        f_normal = force_normal_mag * n̂

        # ============================
        # Compute the tangential force
        # ============================

        # Extract the tangential component of the velocity.
        v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, n̂) * n̂

        # Extract the normal and tangential components of the material deformation.
        m_normal = jnp.dot(m, n̂) * n̂
        m_tangential = m - jnp.dot(m, n̂) * n̂

        # Compute the tangential force in the sticking case.
        # Using the tangential component of the material deformation should not be
        # necessary if the sticking-slipping transition occurs in a terrain area
        # with a locally constant normal. However, this assumption is not true in
        # general, especially for highly uneven terrains.
        f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential)

        # Detect the contact type (sticking or slipping).
        # Note that if there is no contact, sticking is set to True, and this detail
        # is exploited in the computation of the `contact_status` variable.
        sticking = jnp.logical_or(
            δ <= 0, f_tangential.dot(f_tangential) <= (μ * force_normal_mag) ** 2
        )

        # Compute the direction of the tangential force.
        # To prevent dividing by zero, ε is added to the norm when it is zero.
        norm = jaxsim.math.safe_norm(f_tangential)
        f_tangential_direction = f_tangential / (norm + ε * (norm == 0))

        # Project the tangential force to the friction cone if slipping.
        f_tangential = jnp.where(
            sticking,
            f_tangential,
            jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction,
        )

        # Set the tangential force to zero if there is no contact.
        f_tangential = jnp.where(δ <= 0, jnp.zeros(3), f_tangential)

        # =====================================
        # Compute the material deformation rate
        # =====================================

        # Compute the derivative of the material deformation.
        # Note that we included an additional relaxation of `m_normal` in the
        # sticking case, so that the normal deformation that could have accumulated
        # from a previous slipping phase can relax to zero.
        ṁ_no_contact = -(K / D) * m
        ṁ_sticking = v_tangential - (K / D) * m_normal
        ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq)

        # Compute the contact status:
        # 0: slipping
        # 1: sticking
        # 2: no contact
        contact_status = sticking.astype(int)
        contact_status += (δ <= 0).astype(int)

        # Select the right material deformation rate depending on the contact status.
        ṁ = jax.lax.select_n(contact_status, ṁ_slipping, ṁ_sticking, ṁ_no_contact)

        # ==========================================
        # Compute and return the final contact force
        # ==========================================

        # Sum the normal and tangential forces.
        CW_fl = f_normal + f_tangential

        return CW_fl, ṁ

    @staticmethod
    @functools.partial(jax.jit, static_argnames=("terrain",))
    def compute_contact_force(
        position: jtp.VectorLike,
        velocity: jtp.VectorLike,
        tangential_deformation: jtp.VectorLike,
        parameters: SoftContactsParams,
        terrain: Terrain,
    ) -> tuple[jtp.Vector, jtp.Vector]:
        """
        Compute the contact force.

        Args:
            position: The position of the collidable point.
            velocity: The velocity of the collidable point.
            tangential_deformation: The material deformation of the collidable point.
            parameters: The parameters of the soft contacts model.
            terrain: The terrain model.

        Returns:
            A tuple containing the computed 6D contact force and the derivative
            of the material deformation.
        """

        CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model(
            position=position,
            velocity=velocity,
            tangential_deformation=tangential_deformation,
            terrain=terrain,
            K=parameters.K,
            D=parameters.D,
            mu=parameters.mu,
            p=parameters.p,
            q=parameters.q,
        )

        # Pack a mixed 6D force.
        CW_f = jnp.hstack([CW_fl, jnp.zeros(3)])

        # Compute the 6D force transform from the mixed to the inertial-fixed frame.
        W_Xf_CW = jaxsim.math.Adjoint.from_quaternion_and_translation(
            translation=jnp.array(position), inverse=True
        ).T

        # Compute the 6D force in the inertial-fixed frame.
        W_f = W_Xf_CW @ CW_f

        return W_f, ṁ

    @staticmethod
    @jax.jit
    def compute_contact_forces(
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the contact forces.

        Args:
            model: The model to consider.
            data: The data of the considered model.

        Returns:
            A tuple containing as first element the computed contact forces, and as
            second element a dictionary with derivative of the material deformation.
        """

        # Get the indices of the enabled collidable points.
        indices_of_enabled_collidable_points = (
            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
        )

        # Compute the position and linear velocities (mixed representation) of
        # all the collidable points belonging to the robot and extract the ones
        # for the enabled collidable points.
        W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data)

        # Extract the material deformation corresponding to the collidable points.
        m = (
            data.contact_state["tangential_deformation"]
            if "tangential_deformation" in data.contact_state
            else jnp.zeros_like(W_p_C)
        )

        m_enabled = m[indices_of_enabled_collidable_points]

        # Initialize the tangential deformation rate array for every collidable point.
        ṁ = jnp.zeros_like(m)

        # Compute the contact forces only for the enabled collidable points.
        # Since we treat them as independent, we can vmap the computation.
        W_f, ṁ_enabled = jax.vmap(
            lambda p, v, m: SoftContacts.compute_contact_force(
                position=p,
                velocity=v,
                tangential_deformation=m,
                parameters=model.contact_params,
                terrain=model.terrain,
            )
        )(W_p_C, W_ṗ_C, m_enabled)

        # Store the deformation rates of the enabled points, leaving the rate of
        # the other points to zero.
        ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled)

        return W_f, {"m_dot": ṁ}
================================================
FILE: src/jaxsim/rbda/crba.py
================================================
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.typing as jtp
from . import utils
def crba(
    model: js.model.JaxSimModel,
    *,
    joint_positions: jtp.Vector,
) -> jtp.Matrix:
    """
    Build the free-floating mass matrix with the Composite Rigid-Body Algorithm (CRBA).

    Args:
        model: The model to consider.
        joint_positions: The positions of the joints.

    Returns:
        The free-floating mass matrix of the model in body-fixed representation,
        a square matrix with ``6 + model.dofs()`` rows.
    """

    _, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
        model=model, joint_positions=joint_positions
    )

    n_links = model.number_of_links()

    # Composite inertias, initialized with the 6D spatial inertia of each link.
    Mc = js.model.link_spatial_inertia_matrices(model=model)

    # Parent array; its entry 0 is initialized to -1 and must not be used.
    parent = model.kin_dyn_parameters.parent_array

    # Parent-to-child adjoints of all joints, computed with an identity base transform.
    i_X_parent = model.kin_dyn_parameters.joint_transforms(
        joint_positions=s, base_transform=jnp.eye(4)
    )

    # Joint motion subspaces.
    S = model.kin_dyn_parameters.motion_subspaces

    # Buffer of link -> base adjoints; the base maps onto itself.
    i_X_0 = jnp.zeros(shape=(n_links, 6, 6)).at[0].set(jnp.eye(6))

    def leaf_to_root_indices():
        # Link indices n_links-1, ..., 1 (the base link is excluded).
        return jnp.flip(jnp.arange(start=1, stop=n_links))

    # ------------------------------------------------------------------
    # Forward pass: compose the link -> base adjoints along the tree.
    # ------------------------------------------------------------------

    def propagate_kinematics(carry, i):
        (link_to_base,) = carry
        composed = i_X_parent[i] @ link_to_base[parent[i]]
        return (link_to_base.at[i].set(composed),), None

    if n_links > 1:
        (i_X_0,), _ = jax.lax.scan(
            f=propagate_kinematics,
            init=(i_X_0,),
            xs=jnp.arange(start=1, stop=n_links),
        )

    # ------------------------------------------------------------------
    # Backward pass: accumulate composite inertias and fill the matrix.
    # ------------------------------------------------------------------

    M = jnp.zeros(shape=(6 + model.dofs(), 6 + model.dofs()))

    def backward_pass(carry, i):
        Mc, M = carry

        # Row/column of joint i in M: 6 base entries, then joint index i - 1.
        col_i = (i - 1) + 6

        # Add the composite inertia of link i to the one of its parent.
        X_i = i_X_parent[i]
        p = parent[i]
        Mc = Mc.at[p].set(Mc[p] + X_i.T @ Mc[i] @ X_i)

        Fi = Mc[i] @ S[i]
        M = M.at[col_i, col_i].set((S[i].T @ Fi).squeeze())

        # The CRBA while loop walking from link i towards the base is emulated
        # with a fixed-length scan whose iterations are skipped through a cond,
        # so that reverse-mode AD can be applied.
        def fake_while_loop(state, k):
            def climb(state):
                j, F, M = state
                F = i_X_parent[j].T @ F
                j = parent[j]
                M_ij = (F.T @ S[j]).squeeze()
                col_j = (j - 1) + 6
                M = M.at[col_i, col_j].set(M_ij)
                M = M.at[col_j, col_i].set(M_ij)
                return j, F, M

            j = state[0]
            next_state = jax.lax.cond(
                pred=jnp.logical_and(k == parent[j], parent[j] > 0),
                true_fun=climb,
                false_fun=lambda state: state,
                operand=state,
            )
            return next_state, None

        j = i
        if n_links > 1:
            (j, Fi, M), _ = jax.lax.scan(
                f=fake_while_loop,
                init=(j, Fi, M),
                xs=leaf_to_root_indices(),
            )

        # Coupling between joint i and the 6 base coordinates.
        Fi = i_X_0[j].T @ Fi
        M = M.at[0:6, col_i].set(Fi.squeeze())
        M = M.at[col_i, 0:6].set(Fi.squeeze())

        return (Mc, M), None

    if n_links > 1:
        (Mc, M), _ = jax.lax.scan(
            f=backward_pass,
            init=(Mc, M),
            xs=leaf_to_root_indices(),
        )

    # The top-left 6x6 block is the composite inertia accumulated on the base link.
    return M.at[0:6, 0:6].set(Mc[0])
================================================
FILE: src/jaxsim/rbda/forward_kinematics.py
================================================
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Adjoint
from . import utils
def forward_kinematics_model(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike,
    base_quaternion: jtp.VectorLike,
    joint_positions: jtp.VectorLike,
    base_linear_velocity_inertial: jtp.VectorLike,
    base_angular_velocity_inertial: jtp.VectorLike,
    joint_velocities: jtp.VectorLike,
    joint_transforms: jtp.MatrixLike,
) -> tuple[jtp.Array, jtp.Array]:
    """
    Run the forward kinematics of all the links of a model.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity_inertial:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity_inertial:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.
        joint_transforms: The parent-to-child transforms of the joints.

    Returns:
        A tuple containing the SE(3) transforms and the 6D velocities of all links.
    """

    _, _, _, W_v_WB, s_dot, _, _, _, _, _ = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity_inertial,
        base_angular_velocity=base_angular_velocity_inertial,
        joint_velocities=joint_velocities,
    )

    n_links = model.number_of_links()

    # Parent array; its entry 0 is initialized to -1 and must not be used.
    parent = model.kin_dyn_parameters.parent_array

    # Parent-to-child adjoints of the joints.
    i_X_parent = jnp.asarray(joint_transforms)

    # World -> link adjoints; entry 0 is the inverse of the first joint transform.
    W_X_i = (
        jnp.zeros(shape=(n_links, 6, 6)).at[0].set(Adjoint.inverse(i_X_parent[0]))
    )

    # 6D link velocities; entry 0 is the base velocity.
    W_v_Wi = jnp.zeros(shape=(n_links, 6)).at[0].set(W_v_WB)

    # Joint motion subspaces.
    S = model.kin_dyn_parameters.motion_subspaces

    def propagate_kinematics(carry, i):
        transforms, velocities = carry
        p = parent[i]

        # World -> child adjoint obtained from the one of the parent link.
        parent_X_i = Adjoint.inverse(adjoint=i_X_parent[i])
        transforms = transforms.at[i].set(transforms[p] @ parent_X_i)

        # Parent velocity plus the joint contribution; joint i has index i - 1.
        joint_velocity = (S[i] * s_dot[i - 1]).squeeze()
        velocities = velocities.at[i].set(
            velocities[p] + transforms[i] @ joint_velocity
        )

        return (transforms, velocities), None

    if n_links > 1:
        (W_X_i, W_v_Wi), _ = jax.lax.scan(
            f=propagate_kinematics,
            init=(W_X_i, W_v_Wi),
            xs=jnp.arange(start=1, stop=n_links),
        )

    return jax.vmap(Adjoint.to_transform)(W_X_i), W_v_Wi
================================================
FILE: src/jaxsim/rbda/forward_kinematics_parallel.py
================================================
import math
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Adjoint
from . import utils
def forward_kinematics_model_parallel(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike,
    base_quaternion: jtp.VectorLike,
    joint_positions: jtp.VectorLike,
    base_linear_velocity_inertial: jtp.VectorLike,
    base_angular_velocity_inertial: jtp.VectorLike,
    joint_velocities: jtp.VectorLike,
    joint_transforms: jtp.MatrixLike,
) -> tuple[jtp.Array, jtp.Array]:
    """
    Run the forward kinematics with pointer jumping on the kinematic tree.

    Every link starts from its local transform-velocity pair and repeatedly
    composes it with the pair of the link it currently points to, while the
    pointer itself is replaced by the pointer of its target. The number of
    rounds is logarithmic in the number of tree levels.

    The interface and semantics are identical to
    :func:`forward_kinematics_model`, but parallelized via pointer jumping.

    Returns:
        A tuple containing the SE(3) transforms and the 6D velocities of all links.
    """

    _, _, _, W_v_WB, s_dot, _, _, _, _, _ = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity_inertial,
        base_angular_velocity=base_angular_velocity_inertial,
        joint_velocities=joint_velocities,
    )

    # Parent-to-child adjoints of the joints and joint motion subspaces.
    i_X_parent = jnp.asarray(joint_transforms)
    S = model.kin_dyn_parameters.motion_subspaces

    n = model.number_of_links()

    # Local parent -> child adjoints, shape (n, 6, 6).
    local_X = jax.vmap(Adjoint.inverse)(i_X_parent)

    # Local velocity contributions, shape (n, 6). A zero is prepended to the
    # joint velocities for the base link, whose entry is then overwritten with
    # the base velocity.
    s_dot_padded = jnp.concatenate([jnp.zeros(1), jnp.atleast_1d(s_dot.squeeze())])
    joint_velocity = (S * s_dot_padded[:, None, None]).squeeze(-1)
    local_v = jnp.einsum("nij,nj->ni", local_X, joint_velocity)
    local_v = local_v.at[0].set(W_v_WB)

    # Parent pointers, with the root pointing to itself, and completion flags
    # that initially mark only the root.
    pointers = jnp.asarray(model.kin_dyn_parameters.parent_array).at[0].set(0)
    finished = jnp.arange(n) == 0

    # Number of pointer-jumping rounds.
    n_levels = model.kin_dyn_parameters.level_nodes.shape[0]
    n_rounds = max(1, math.ceil(math.log2(max(n_levels, 2))))

    def _pointer_jump(carry, _):
        X, v, ptr, done = carry

        pending = jnp.logical_not(done)
        X_target = X[ptr]
        v_target = v[ptr]

        # Compose the state of the pending links with the state of their target.
        X_next = jnp.where(pending[:, None, None], X_target @ X, X)
        v_next = jnp.where(
            pending[:, None],
            v_target + jnp.einsum("nij,nj->ni", X_target, v),
            v,
        )

        # Jump to the target's pointer; a link is done once its target was done.
        ptr_next = jnp.where(pending, ptr[ptr], ptr)
        done_next = jnp.logical_or(done, done[ptr])

        return (X_next, v_next, ptr_next, done_next), None

    if n > 1:
        (W_X_i, W_v_Wi, _, _), _ = jax.lax.scan(
            f=_pointer_jump,
            init=(local_X, local_v, pointers, finished),
            xs=jnp.arange(n_rounds),
        )
    else:
        W_X_i, W_v_Wi = local_X, local_v

    return jax.vmap(Adjoint.to_transform)(W_X_i), W_v_Wi
================================================
FILE: src/jaxsim/rbda/jacobian.py
================================================
import jax
import jax.numpy as jnp
import numpy as np
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Adjoint, Cross
from . import utils
def jacobian(
    model: js.model.JaxSimModel,
    *,
    link_index: jtp.Int,
    joint_positions: jtp.VectorLike,
) -> jtp.Matrix:
    """
    Build the free-floating Jacobian of a single link.

    Args:
        model: The model to consider.
        link_index: The index of the link for which to compute the Jacobian matrix.
        joint_positions: The positions of the joints.

    Returns:
        The free-floating left-trivialized Jacobian of the link :math:`{}^L J_{W,L/B}`.
    """

    _, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
        model=model, joint_positions=joint_positions
    )

    n_links = model.number_of_links()

    # Parent array; its entry 0 is initialized to -1 and must not be used.
    parent = model.kin_dyn_parameters.parent_array

    # Parent-to-child adjoints of all joints, computed with an identity base transform.
    i_X_parent = model.kin_dyn_parameters.joint_transforms(
        joint_positions=s, base_transform=jnp.eye(4)
    )

    # Joint motion subspaces.
    S = model.kin_dyn_parameters.motion_subspaces

    # Buffer of link -> base adjoints; the base maps onto itself.
    i_X_0 = jnp.zeros(shape=(n_links, 6, 6)).at[0].set(jnp.eye(6))

    def propagate_kinematics(carry, i):
        (link_to_base,) = carry
        # Links are visited by increasing index, so the parent entry is
        # already filled when link i is processed.
        composed = i_X_parent[i] @ link_to_base[parent[i]]
        return (link_to_base.at[i].set(composed),), None

    if n_links > 1:
        (i_X_0,), _ = jax.lax.scan(
            f=propagate_kinematics,
            init=(i_X_0,),
            xs=np.arange(start=1, stop=n_links),
        )

    # Adjoint of the target link; it is also the base block of the Jacobian.
    L_X_0 = i_X_0[link_index]
    J = jnp.zeros(shape=(6, 6 + model.dofs())).at[0:6, 0:6].set(L_X_0)

    # Boolean mask of the links supporting the target link.
    in_support = model.kin_dyn_parameters.support_body_array_bool[link_index]

    def compute_jacobian(J, i):
        # Candidate column of joint i (stored at 6 + i - 1); it is kept only
        # if link i supports the target link.
        column = (L_X_0 @ Adjoint.inverse(i_X_0[i]) @ S[i]).squeeze()
        J_with_column = J.at[0:6, 6 + (i - 1)].set(column)

        J = jax.lax.select(
            pred=in_support[i],
            on_true=J_with_column,
            on_false=J,
        )
        return J, None

    if n_links > 1:
        J, _ = jax.lax.scan(
            f=compute_jacobian,
            init=J,
            xs=np.arange(start=1, stop=n_links),
        )

    return J
@jax.jit
def jacobian_full_doubly_left(
    model: js.model.JaxSimModel,
    *,
    joint_positions: jtp.VectorLike,
) -> tuple[jtp.Matrix, jtp.Array]:
    r"""
    Build the doubly-left full free-floating Jacobian of a model.

    The full Jacobian is a 6x(6+n) matrix with all the columns filled.
    Running the algorithm once allows extracting the Jacobian of any link by
    filtering the columns of the full Jacobian with the support parent array
    :math:`\kappa(i)` of the link.

    Args:
        model: The model to consider.
        joint_positions: The positions of the joints.

    Returns:
        A tuple containing the doubly-left full free-floating Jacobian of the
        model and the SE(3) transforms obtained from the base-to-link adjoints
        of all links.
    """

    _, _, s, _, _, _, _, _, _, _ = utils.process_inputs(
        model=model, joint_positions=joint_positions
    )

    n_links = model.number_of_links()

    # Parent array; its entry 0 is initialized to -1 and must not be used.
    parent = model.kin_dyn_parameters.parent_array

    # Parent-to-child adjoints of all joints, computed with an identity base transform.
    i_X_parent = model.kin_dyn_parameters.joint_transforms(
        joint_positions=s, base_transform=jnp.eye(4)
    )

    # Joint motion subspaces.
    S = model.kin_dyn_parameters.motion_subspaces

    # Buffer of base -> link adjoints; the base maps onto itself.
    B_X_i = jnp.zeros(shape=(n_links, 6, 6)).at[0].set(jnp.eye(6))

    # The base block of the doubly-left Jacobian is the identity.
    J = jnp.zeros(shape=(6, 6 + model.dofs())).at[0:6, 0:6].set(jnp.eye(6))

    def compute_full_jacobian(carry, i):
        adjoints, J = carry

        # Base -> link adjoint of link i from the one of its parent.
        B_X_link = adjoints[parent[i]] @ Adjoint.inverse(i_X_parent[i])
        adjoints = adjoints.at[i].set(B_X_link)

        # Column of joint i, stored at index 6 + i - 1.
        column = (B_X_link @ S[i]).squeeze()
        J = J.at[0:6, 6 + (i - 1)].set(column)

        return (adjoints, J), None

    if n_links > 1:
        (B_X_i, J), _ = jax.lax.scan(
            f=compute_full_jacobian,
            init=(B_X_i, J),
            xs=np.arange(start=1, stop=n_links),
        )

    # Returning the transforms avoids calling FK again when the output
    # representation of the Jacobian needs to be changed.
    B_H_L = jax.vmap(Adjoint.to_transform)(B_X_i)

    return J.squeeze().astype(float), B_H_L
def jacobian_derivative_full_doubly_left(
    model: js.model.JaxSimModel,
    *,
    joint_positions: jtp.VectorLike,
    joint_velocities: jtp.VectorLike,
) -> tuple[jtp.Matrix, jtp.Array]:
    r"""
    Compute the derivative of the doubly-left full free-floating Jacobian of a model.

    The derivative of the full Jacobian is a 6x(6+n) matrix with all the columns filled.
    It is useful to run the algorithm once, and then extract the link Jacobian
    derivative by filtering the columns of the full Jacobian using the support
    parent array :math:`\kappa(i)` of the link.

    Args:
        model: The model to consider.
        joint_positions: The positions of the joints.
        joint_velocities: The velocities of the joints.

    Returns:
        A tuple containing the derivative of the doubly-left full free-floating
        Jacobian of the model and the SE(3) transforms obtained from the
        base-to-link adjoints of all links.
    """

    _, _, s, _, ṡ, _, _, _, _, _ = utils.process_inputs(
        model=model, joint_positions=joint_positions, joint_velocities=joint_velocities
    )

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the parent-to-child adjoints of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    i_X_λi = model.kin_dyn_parameters.joint_transforms(
        joint_positions=s, base_transform=jnp.eye(4)
    )

    # Extract the joint motion subspaces.
    S = model.kin_dyn_parameters.motion_subspaces

    # Allocate the buffer of 6D transform base -> link.
    B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    B_X_i = B_X_i.at[0].set(jnp.eye(6))

    # Allocate the buffer of 6D transform derivatives base -> link.
    B_Ẋ_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))

    # Allocate the buffer of the 6D link velocity in body-fixed representation.
    B_v_Bi = jnp.zeros(shape=(model.number_of_links(), 6))

    # Helper to compute the time derivative of the adjoint matrix.
    def A_Ẋ_B(A_X_B: jtp.Matrix, B_v_AB: jtp.Vector) -> jtp.Matrix:
        return A_X_B @ Cross.vx(B_v_AB).squeeze()

    # ============================================
    # Compute doubly-left full Jacobian derivative
    # ============================================

    # Allocate the Jacobian matrix.
    J̇ = jnp.zeros(shape=(6, 6 + model.dofs()))

    ComputeFullJacobianDerivativeCarry = tuple[
        jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix
    ]

    compute_full_jacobian_derivative_carry: ComputeFullJacobianDerivativeCarry = (
        B_v_Bi,
        B_X_i,
        B_Ẋ_i,
        J̇,
    )

    def compute_full_jacobian_derivative(
        carry: ComputeFullJacobianDerivativeCarry, i: jtp.Int
    ) -> tuple[ComputeFullJacobianDerivativeCarry, None]:

        # Joint i is stored at index i - 1 of the joint quantities.
        ii = i - 1
        B_v_Bi, B_X_i, B_Ẋ_i, J̇ = carry

        # Compute the base (0) to link (i) adjoint matrix.
        B_Xi_i = B_X_i[λ[i]] @ Adjoint.inverse(i_X_λi[i])
        B_X_i = B_X_i.at[i].set(B_Xi_i)

        # Compute the body-fixed velocity of the link.
        B_vi_Bi = B_v_Bi[λ[i]] + B_X_i[i] @ S[i].squeeze() * ṡ[ii]
        B_v_Bi = B_v_Bi.at[i].set(B_vi_Bi)

        # Compute the base (0) to link (i) adjoint matrix derivative.
        i_Xi_B = Adjoint.inverse(B_Xi_i)
        B_Ẋi_i = A_Ẋ_B(A_X_B=B_Xi_i, B_v_AB=i_Xi_B @ B_vi_Bi)
        B_Ẋ_i = B_Ẋ_i.at[i].set(B_Ẋi_i)

        # Compute the ii-th column of the B_Ṡ_BL(s) matrix.
        B_Ṡii_BL = B_Ẋ_i[i] @ S[i]
        J̇ = J̇.at[0:6, 6 + ii].set(B_Ṡii_BL.squeeze())

        return (B_v_Bi, B_X_i, B_Ẋ_i, J̇), None

    # Models with a single link have no joints: the buffers allocated above
    # (identity base adjoint, zero Jacobian derivative) are already the result.
    # An explicit branch avoids building a placeholder carry out of the
    # throwaway `_` binding produced when unpacking `process_inputs`.
    if model.number_of_links() > 1:
        (_, B_X_i, B_Ẋ_i, J̇), _ = jax.lax.scan(
            f=compute_full_jacobian_derivative,
            init=compute_full_jacobian_derivative_carry,
            xs=np.arange(start=1, stop=model.number_of_links()),
        )

    # Convert adjoints to SE(3) transforms.
    # Returning them here prevents calling FK in case the output representation
    # of the Jacobian needs to be changed.
    B_H_L = jax.vmap(Adjoint.to_transform)(B_X_i)

    # Adjust shape of doubly-left free-floating full Jacobian derivative.
    B_J̇_full_WL_B = J̇.squeeze().astype(float)

    return B_J̇_full_WL_B, B_H_L
================================================
FILE: src/jaxsim/rbda/kinematic_constraints.py
================================================
from __future__ import annotations
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
from jaxsim.api.kin_dyn_parameters import ConstraintMap
from jaxsim.math.adjoint import Adjoint
from jaxsim.math.rotation import Rotation
from jaxsim.math.transform import Transform
# Utility functions used for constraints computation. These functions duplicate part of the jaxsim.api.frame module for computational efficiency.
# TODO: remove these functions when jaxsim.api.frame is optimized for batched computations.
# See: https://github.com/gbionics/jaxsim/issues/451
def _compute_constraint_transforms_batched(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    constraints: ConstraintMap,
) -> jtp.Matrix:
    """
    Compute the world transforms of the two frames of every kinematic constraint.

    Args:
        model: The JaxSim model containing the robot description.
        data: The model data containing current state information.
        constraints: The constraint map containing frame indices and parent link information.

    Returns:
        The transforms of the constrained frames, stacked along axis 1 as
        [W_H_F1, W_H_F2] for each constraint (shape (n_constraints, 2, 4, 4)
        according to the callers' documentation).
    """

    link_transforms = data._link_transforms
    frame_transforms = model.kin_dyn_parameters.frame_parameters.transform
    n_links = model.number_of_links()

    def world_H_frame(frame_idxs, parent_link_idxs):
        # Frame indices are offset by the number of links before being used to
        # address the frame parameters.
        L_H_F = frame_transforms[frame_idxs - n_links]
        return link_transforms[parent_link_idxs] @ L_H_F

    W_H_F1 = world_H_frame(constraints.frame_idxs_1, constraints.parent_link_idxs_1)
    W_H_F2 = world_H_frame(constraints.frame_idxs_2, constraints.parent_link_idxs_2)

    return jnp.stack([W_H_F1, W_H_F2], axis=1)
def _compute_constraint_jacobians_batched(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    constraints: ConstraintMap,
    W_H_constraint_pairs: jtp.Matrix,
) -> jtp.Matrix:
    """
    Compute the Jacobians of all the kinematic constraints in a batched manner.

    The Jacobian of a constraint is the difference between the Jacobians of its
    two frames, each one obtained by transforming the Jacobian of the parent
    link of the frame.

    Args:
        model: The JaxSim model containing the robot description.
        data: The model data containing current state information.
        constraints: The constraint map containing frame indices and parent link information.
        W_H_constraint_pairs: Transformation matrices for constraint frame pairs with shape
            (n_constraints, 2, 4, 4).

    Returns:
        The stacked constraint Jacobians, one 6-row matrix per constraint.
        NOTE(review): the number of columns is the one of the free-floating
        Jacobian returned by `generalized_free_floating_jacobian`, presumably
        6 + n_dofs -- confirm.
    """

    # NOTE(review): the link Jacobians are requested while the data is switched
    # to the body-fixed representation, whereas `compute_constraint_wrenches`
    # multiplies the result by mixed-representation quantities -- confirm that
    # the input representations agree.
    with data.switch_velocity_representation(VelRepr.Body):
        link_jacobians = js.model.generalized_free_floating_jacobian(
            model=model, data=data, output_vel_repr=VelRepr.Body
        )

    link_transforms = data._link_transforms

    def frame_jacobian(W_H_F, parent_link_index):
        # Transform from the parent link to the frame.
        F_H_L = Transform.inverse(W_H_F) @ link_transforms[parent_link_index]

        # Same orientation as W_H_F, with the translation zeroed out.
        FW_H_F = W_H_F.at[0:3, 3].set(jnp.zeros(3))

        FW_X_L = Adjoint.from_transform(transform=FW_H_F @ F_H_L)
        return FW_X_L @ link_jacobians[parent_link_index]

    def pair_jacobian(W_H_pair, constraint):
        J_WF1 = frame_jacobian(W_H_pair[0], constraint.parent_link_idxs_1)
        J_WF2 = frame_jacobian(W_H_pair[1], constraint.parent_link_idxs_2)
        return J_WF1 - J_WF2

    # Map over the constraints; the link Jacobians are shared through the closure.
    return jax.vmap(pair_jacobian, in_axes=(0, 0))(W_H_constraint_pairs, constraints)
def _compute_constraint_baumgarte_term(
    J_constr: jtp.Matrix,
    nu: jtp.Vector,
    W_H_F_constr: jtp.Matrix,
    constraint: ConstraintMap,
) -> jtp.Vector:
    """
    Evaluate the Baumgarte stabilization term of a kinematic constraint.

    The term combines a proportional action on the pose error between the two
    constrained frames with a derivative action on the constraint velocity
    ``J_constr @ nu``.

    Args:
        J_constr: The constraint Jacobian matrix.
        nu: The generalized velocity vector multiplied by the Jacobian.
        W_H_F_constr: Array containing the homogeneous transformation matrices
            of two frames [W_H_F1, W_H_F2] with respect to the world frame,
            with shape (2, 4, 4).
        constraint: The constraint object containing stabilization gains K_P and K_D.

    Returns:
        The computed Baumgarte stabilization term.
    """

    W_H_F1, W_H_F2 = W_H_F_constr

    # Difference between the origins of the two frames.
    position_error = W_H_F1[0:3, 3] - W_H_F2[0:3, 3]

    # Relative rotation between the two frames mapped through `log_vee`.
    relative_rotation = W_H_F2[0:3, 0:3].T @ W_H_F1[0:3, 0:3]
    orientation_error = Rotation.log_vee(relative_rotation)

    pose_error = jnp.concatenate([position_error, orientation_error])
    velocity_error = J_constr @ nu

    return constraint.K_P * pose_error + constraint.K_D * velocity_error
@jax.jit
@js.common.named_scope
def compute_constraint_wrenches(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_force_references: jtp.VectorLike | None = None,
    link_forces_inertial: jtp.MatrixLike | None = None,
    regularization: jtp.Float = 1e-3,
) -> jtp.Matrix:
    """
    Solve for the wrenches enforcing the kinematic constraints of a model.

    The wrenches are obtained from a regularized linear system built with the
    constraint Jacobians, the inverse of the mass matrix, the unconstrained
    acceleration, and a Baumgarte stabilization term. The solution is then
    converted to inertial representation.

    Args:
        model: The JaxSim model.
        data: The model data.
        joint_force_references: Optional joint force/torque references to apply. If None,
            zero forces are used.
        link_forces_inertial: Optional link forces applied in inertial representation.
            If None, zero forces are used.
        regularization: Value added to the diagonal of the linear system to improve
            numerical stability. Default is 1e-3.

    Returns:
        Array with shape (n_constraints, 2, 6) containing constraint wrench pairs
        in inertial representation. The two wrenches of a pair are built from a
        wrench and its negation, each converted with the transform of its own
        frame. An array with shape (0, 2, 6) is returned if the model has no
        kinematic constraints.
    """

    kin_constraints = model.kin_dyn_parameters.constraints

    # Nothing to solve without constraints.
    has_constraints = (
        kin_constraints is not None and kin_constraints.frame_idxs_1.shape[0] > 0
    )
    if not has_constraints:
        return jnp.zeros((0, 2, 6))

    # Every constraint contributes 6 rows to the linear system.
    n_kin_constraints = 6 * kin_constraints.frame_idxs_1.shape[0]

    # Joint forces, defaulting to zero.
    if joint_force_references is not None:
        joint_forces = jnp.asarray(joint_force_references, dtype=float)
    else:
        joint_forces = jnp.zeros_like(data.joint_positions)

    # Link forces, defaulting to zero.
    if link_forces_inertial is not None:
        W_f_L = jnp.atleast_2d(jnp.array(link_forces_inertial).squeeze())
    else:
        W_f_L = jnp.zeros((model.number_of_links(), 6))
    W_f_L = W_f_L.astype(float)

    # The references object takes care of converting between representations.
    references = js.references.JaxSimModelReferences.build(
        model=model,
        joint_force_references=joint_forces,
        link_forces=W_f_L,
        velocity_representation=VelRepr.Inertial,
    )

    with (
        data.switch_velocity_representation(VelRepr.Mixed),
        references.switch_velocity_representation(VelRepr.Mixed),
    ):
        generalized_velocity = data.generalized_velocity

        # Acceleration of the model without the constraint wrenches.
        free_acceleration = jnp.hstack(
            js.model.forward_dynamics_aba(
                model=model,
                data=data,
                link_forces=references.link_forces(model=model, data=data),
                joint_forces=references.joint_force_references(model=model),
            )
        )

        # Inverse of the free-floating mass matrix.
        M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data)

        # Transforms and Jacobians of the constrained frame pairs.
        W_H_constr_pairs = _compute_constraint_transforms_batched(
            model=model,
            data=data,
            constraints=kin_constraints,
        )
        J_per_constraint = _compute_constraint_jacobians_batched(
            model=model,
            data=data,
            constraints=kin_constraints,
            W_H_constraint_pairs=W_H_constr_pairs,
        )

        # Baumgarte term of every constraint, flattened into a single vector.
        baumgarte_per_constraint = jax.vmap(
            _compute_constraint_baumgarte_term,
            in_axes=(0, None, 0, 0),
        )(
            J_per_constraint,
            generalized_velocity,
            W_H_constr_pairs,
            kin_constraints,
        )
        baumgarte_term = jnp.ravel(baumgarte_per_constraint)

        # Stack the per-constraint Jacobians along the rows.
        J_constr = jnp.vstack(J_per_constraint)

        # Delassus matrix of the constraints.
        delassus = J_constr @ M_inv @ J_constr.T

        # Constraint acceleration due to the free motion.
        # TODO: add the contribution of the Jacobian derivative.
        free_constraint_acceleration = J_constr @ free_acceleration

        # Regularized linear system A x = -b.
        R = jnp.diag(regularization * jnp.ones(n_kin_constraints))
        A = delassus + R
        b = free_constraint_acceleration + baumgarte_term

        # One 6D wrench per constraint.
        kin_constr_wrench_mixed = jnp.linalg.solve(A, -b).reshape(-1, 6)

    def mixed_to_inertial(wrench, W_H_F):
        # Convert a mixed-representation wrench using the transform of its frame.
        return ModelDataWithVelocityRepresentation.other_representation_to_inertial(
            array=wrench,
            transform=W_H_F,
            other_representation=VelRepr.Mixed,
            is_force=True,
        )

    def transform_wrenches_to_inertial(wrench, transform_pair):
        """
        Build the pair of inertial-representation wrenches of a constraint.

        Args:
            wrench: Wrench vector with shape (6,).
            transform_pair: Pair of transformation matrices [W_H_F1, W_H_F2]

        Returns:
            Stack of transformed wrenches with shape (2, 6).
        """
        wrench_F1_inertial = mixed_to_inertial(wrench, transform_pair[0])
        wrench_F2_inertial = mixed_to_inertial(-wrench, transform_pair[1])
        return jnp.stack([wrench_F1_inertial, wrench_F2_inertial])

    return jax.vmap(transform_wrenches_to_inertial)(
        kin_constr_wrench_mixed, W_H_constr_pairs
    )
================================================
FILE: src/jaxsim/rbda/mass_inverse.py
================================================
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.typing as jtp
def mass_inverse(
    model: js.model.JaxSimModel,
    *,
    joint_transforms: jtp.MatrixLike,
) -> jtp.Matrix:
    """
    Compute the inverse of the mass matrix using an ABA-like algorithm.

    The implementation follows the approach described in https://laas.hal.science/hal-01790934v2.

    The computation is organized in three steps:

    1. A backward scan over the links (from the last link down to link 1) that
       fills the joint rows of the inverse and accumulates the inertia and force
       buffers of the parent links.
    2. A closed-form treatment of the 6x6 base block using the inertia accumulated
       on link 0.
    3. A forward scan over the links (from link 1 up to the last link) that
       completes the joint rows using the quantities propagated from the parents.

    Args:
        model: The model to consider.
        joint_transforms: The parent-to-child transforms of the joints.

    Returns:
        The inverse of the mass matrix.

    Note:
        The size of the returned matrix is `model.number_of_joints() + 6`; the six
        base rows and columns are always present since no branch on the base type
        is taken in this function.
    """

    # Get the 6D spatial inertia matrices of all links.
    I_A = js.model.link_spatial_inertia_matrices(model=model)

    # Get the parent array λ(i).
    # λ[0] ~ -1 (world)
    # λ[i] = parent link index for link i.
    λ = model.kin_dyn_parameters.parent_array

    # Extract the parent-to-child adjoints of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    # The conversion to a JAX array allows indexing with the traced scan index.
    i_X_λi = jnp.asarray(joint_transforms)

    # Extract the joint motion subspaces.
    S = model.kin_dyn_parameters.motion_subspaces

    NB = model.number_of_links()
    N = model.number_of_joints()

    # Total generalized velocities: 6 base + N.
    nv = N + 6

    # Allocate buffers.
    # F and P store one (6, nv) block per link, U one 6D vector per link,
    # and D one scalar per link.
    F = jnp.zeros((NB, 6, nv), dtype=float)
    P = jnp.zeros((NB, 6, nv), dtype=float)
    U = jnp.zeros((NB, 6), dtype=float)
    D = jnp.zeros((NB,), dtype=float)

    # Pre-allocate mass matrix inverse
    M_inv = jnp.zeros((nv, nv), dtype=float)

    # Pre-compute indices.
    # Link 0 (the base) is excluded from both scans and handled separately.
    idx_fwd = jnp.arange(1, NB)
    idx_rev = jnp.arange(NB - 1, 0, -1)

    # =============
    # Backward Pass
    # =============

    BackwardPassCarry = tuple[
        jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix
    ]
    backward_pass_carry: BackwardPassCarry = (I_A, F, U, D, M_inv)

    def loop_backward_pass(
        carry: BackwardPassCarry, i: jtp.Int
    ) -> tuple[BackwardPassCarry, None]:
        """Process link `i` and accumulate its contribution on its parent."""

        I_A, F, U, D, M_inv = carry

        # Drop the trailing singleton axis of the motion subspace so that it can be
        # used as a 6D vector (presumably S[i] has shape (6, 1) — TODO confirm).
        Si = jnp.squeeze(S[i], axis=-1)
        Fi = F[i]
        Xi = i_X_λi[i]
        parent = λ[i]

        Ui = I_A[i] @ Si
        Di = jnp.dot(Si, Ui)

        # Store U_i and D_i since they are read again in the forward pass.
        U = U.at[i].set(Ui)
        D = D.at[i].set(Di)

        # Row index in ν for joint i: 6 + (i - 1)
        # NOTE(review): this assumes that link i is moved by the 1-DoF joint i-1.
        r = 6 + (i - 1)

        Minv_row = M_inv[r]

        # Diagonal element
        Minv_row = Minv_row.at[r].add(1.0 / Di)

        # Off-diagonals: Minv[r,:] -= (1/Di) * Sᵢᵀ Fᵢ
        sTFi = jnp.einsum("s,sn->n", Si, Fi)
        Minv_row = Minv_row - sTFi / Di

        M_inv = M_inv.at[r].set(Minv_row)

        # Propagate to parent if any (parent >= 0)
        def propagate(IA_F):
            I_A_, F_ = IA_F

            Ui_col = Ui[:, None]

            # F_a_i = F_i + U_i * Minv[r,:]
            Fa_i = Fi + Ui_col @ Minv_row[None, :]

            # F_parent += Xᵢᵀ F_a_i
            F_parent_new = F_[parent] + Xi.T @ Fa_i
            F_ = F_.at[parent].set(F_parent_new)

            # I_a_i = IAi - U_i D_i^{-1} U_iᵀ
            # The inertia is read from the carry unpacked above, which holds the
            # same values passed to this branch as operand.
            Ia_i = I_A[i] - jnp.outer(Ui, Ui) / Di

            # I_A[parent] += Xᵢᵀ I_a_i Xᵢ
            I_parent_new = I_A_[parent] + Xi.T @ Ia_i @ Xi
            I_A_ = I_A_.at[parent].set(I_parent_new)

            return I_A_, F_

        I_A, F = jax.lax.cond(
            parent >= 0,
            propagate,
            lambda IA_F: IA_F,
            (I_A, F),
        )

        return (I_A, F, U, D, M_inv), None

    (I_A, F, U, D, M_inv), _ = jax.lax.scan(
        loop_backward_pass, backward_pass_carry, idx_rev
    )

    # Base block: the motion subspace of the base is the 6x6 identity, therefore
    # D0 is a 6x6 matrix and it is inverted explicitly.
    S0 = jnp.eye(6, dtype=float)
    U0 = I_A[0] @ S0
    D0 = S0.T @ U0
    D0_inv = jnp.linalg.inv(D0)

    # Base rows 0..5 in ν
    base_rows = slice(0, 6)

    # Diagonal base block
    M_inv = M_inv.at[base_rows, base_rows].add(D0_inv)

    # Off-diagonal base contribution: M_inv[base,:] -= D0^{-T} F[0]
    term0 = D0_inv.T @ F[0]
    M_inv = M_inv.at[base_rows, :].add(-term0)

    # ============
    # Forward Pass
    # ============

    # Initialize P_0 = S0 * Minv[base,:] = I * Minv[base,:]
    Minv_base = M_inv[base_rows, :]
    P = P.at[0].set(Minv_base)

    ForwardPassCarry = tuple[jtp.Matrix, jtp.Matrix]
    forward_pass_carry: ForwardPassCarry = (M_inv, P)

    def loop_forward_pass(
        carry: ForwardPassCarry, i: jtp.Int
    ) -> tuple[ForwardPassCarry, None]:
        """Complete the row of joint `i` using the buffer of its parent link."""

        M_inv, P = carry

        Si = jnp.squeeze(S[i], axis=-1)

        # U and D are the buffers produced by the backward pass.
        Ui = U[i]
        Di = D[i]
        Xi = i_X_λi[i]
        parent = λ[i]

        # Use a zero block when the link has no parent.
        P_parent = jax.lax.cond(
            parent >= 0,
            lambda P_: P_[parent],
            lambda P_: jnp.zeros_like(P_[i]),
            P,
        )

        # Row index in ν for joint i
        r = 6 + (i - 1)

        # Row update: M_inv[r,:] -= D_i^{-1} U_iᵀ Xᵢ P_parent
        def update_row(Minv_):
            X_P = Xi @ P_parent
            UiT_XP = jnp.einsum("s,sn->n", Ui, X_P)
            Minv_row = Minv_[r, :] - UiT_XP / Di
            return Minv_.at[r, :].set(Minv_row)

        M_inv = jax.lax.cond(
            parent >= 0,
            update_row,
            lambda Minv_: Minv_,
            M_inv,
        )

        # Read the row again since it might have just been updated.
        Minv_row = M_inv[r, :]

        # P_i = S_i Minv[r,:] + Xᵢ P_parent
        Pi = jnp.expand_dims(Si, 1) @ jnp.expand_dims(Minv_row, 0)
        Pi = Pi + Xi @ P_parent

        P = P.at[i].set(Pi)

        return (M_inv, P), None

    (M_inv, P), _ = jax.lax.scan(loop_forward_pass, forward_pass_carry, idx_fwd)

    # Symmetrize numerically
    M_inv = 0.5 * (M_inv + M_inv.T)

    return M_inv
================================================
FILE: src/jaxsim/rbda/rnea.py
================================================
import jax
import jax.numpy as jnp
import jaxlie
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import STANDARD_GRAVITY, Adjoint, Cross
from . import utils
def rnea(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.Vector,
    base_quaternion: jtp.Vector,
    joint_positions: jtp.Vector,
    base_linear_velocity: jtp.Vector,
    base_angular_velocity: jtp.Vector,
    joint_velocities: jtp.Vector,
    base_linear_acceleration: jtp.Vector | None = None,
    base_angular_acceleration: jtp.Vector | None = None,
    joint_accelerations: jtp.Vector | None = None,
    joint_transforms: jtp.MatrixLike,
    link_forces: jtp.Matrix | None = None,
    standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA).

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.
        base_linear_acceleration:
            The linear acceleration of the base link in inertial-fixed representation.
        base_angular_acceleration:
            The angular acceleration of the base link in inertial-fixed representation.
        joint_accelerations: The accelerations of the joints.
        joint_transforms: The parent-to-child transforms of the joints.
        link_forces:
            The forces applied to the links expressed in the world frame.
        standard_gravity: The standard gravity constant.

    Returns:
        A tuple containing the 6D force applied to the base link expressed in the
        world frame and the joint forces that, when applied respectively to the base
        link and joints, produce the given base and joint accelerations.

    Note:
        Within this function the joint positions are only used to size the output
        joint forces; the relative kinematics of the links is read from
        `joint_transforms`.
    """

    # Validate and normalize the inputs. The entry discarded with `_` is the
    # joint forces vector, which is not an input of inverse dynamics.
    W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, _, W_f, W_g = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
        base_linear_acceleration=base_linear_acceleration,
        base_angular_acceleration=base_angular_acceleration,
        joint_accelerations=joint_accelerations,
        link_forces=link_forces,
        standard_gravity=standard_gravity,
    )

    # Convert the 6D quantities to column vectors so that they can be
    # left-multiplied by 6x6 transforms.
    W_g = jnp.atleast_2d(W_g).T
    W_v_WB = jnp.atleast_2d(W_v_WB).T
    W_v̇_WB = jnp.atleast_2d(W_v̇_WB).T

    # Get the 6D spatial inertia matrices of all links.
    M = js.model.link_spatial_inertia_matrices(model=model)

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3(wxyz=W_Q_B),
        translation=W_p_B,
    )

    # Compute 6D transforms of the base velocity.
    W_X_B = W_H_B.adjoint()
    B_X_W = W_H_B.inverse().adjoint()

    # Extract the parent-to-child adjoints of the joints.
    # These transforms define the relative kinematics of the entire model, including
    # the base transform for both floating-base and fixed-base models.
    # Ensure cached transforms are JAX arrays so they work with traced indices.
    i_X_λi = jnp.asarray(joint_transforms)

    # Extract the joint motion subspaces.
    S = model.kin_dyn_parameters.motion_subspaces

    # Allocate buffers.
    # Each link stores its 6D velocity, acceleration, and force as a column vector.
    v = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    a = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    f = jnp.zeros(shape=(model.number_of_links(), 6, 1))

    # Allocate the buffer of transforms link -> base.
    i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Initialize the acceleration of the base link.
    # Gravity enters the recursion as a fictitious acceleration of the base,
    # which is the only contribution when the base is fixed.
    a_0 = -B_X_W @ W_g
    a = a.at[0].set(a_0)

    # This is a Python-level branch on the model, not a traced condition.
    if model.floating_base():

        # Base velocity v₀ in body-fixed representation.
        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        # Base acceleration a₀ in body-fixed representation w/o gravity.
        a_0 = B_X_W @ (W_v̇_WB - W_g)
        a = a.at[0].set(a_0)

        # Force applied to the base link that produce the base acceleration w/o gravity.
        f_0 = (
            M[0] @ a[0]
            + Cross.vx_star(v[0]) @ M[0] @ v[0]
            - W_X_B.T @ jnp.vstack(W_f[0])
        )
        f = f.at[0].set(f_0)

    # ======
    # Pass 1
    # ======

    ForwardPassCarry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
    forward_pass_carry: ForwardPassCarry = (v, a, i_X_0, f)

    def forward_pass(
        carry: ForwardPassCarry, i: jtp.Int
    ) -> tuple[ForwardPassCarry, None]:
        """Propagate velocity and acceleration to link `i` and compute its force."""

        # Index used to read the joint quantities associated with link i.
        # NOTE(review): this assumes that joint i-1 is the joint moving link i.
        ii = i - 1
        v, a, i_X_0, f = carry

        # Project the joint velocity into its motion subspace.
        vJ = S[i] * ṡ[ii]

        # Propagate the link velocity.
        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        # Propagate the link acceleration.
        a_i = i_X_λi[i] @ a[λ[i]] + S[i] * s̈[ii] + Cross.vx(v[i]) @ vJ
        a = a.at[i].set(a_i)

        # Compute the link-to-base transform.
        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)

        # Compute link-to-world transform for the 6D force.
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        # Compute the force acting on the link.
        f_i = (
            M[i] @ a[i]
            + Cross.vx_star(v[i]) @ M[i] @ v[i]
            - i_Xf_W @ jnp.vstack(W_f[i])
        )
        f = f.at[i].set(f_i)

        return (v, a, i_X_0, f), None

    # The scan is skipped for single-link models, for which the initial buffers
    # are already the final result of this pass.
    (v, a, i_X_0, f), _ = (
        jax.lax.scan(
            f=forward_pass,
            init=forward_pass_carry,
            xs=jnp.arange(start=1, stop=model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(v, a, i_X_0, f), None]
    )

    # ======
    # Pass 2
    # ======

    τ = jnp.zeros_like(s)

    BackwardPassCarry = tuple[jtp.Vector, jtp.Matrix]
    backward_pass_carry: BackwardPassCarry = (τ, f)

    def backward_pass(
        carry: BackwardPassCarry, i: jtp.Int
    ) -> tuple[BackwardPassCarry, None]:
        """Compute the force of the joint of link `i` and update the parent force."""

        ii = i - 1
        τ, f = carry

        # Project the 6D force to the DoF of the joint.
        τ_i = S[i].T @ f[i]
        τ = τ.at[ii].set(τ_i.squeeze())

        # Propagate the force to the parent link.
        def update_f(f: jtp.Matrix) -> jtp.Matrix:

            f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
            f = f.at[λ[i]].set(f_λi)

            return f

        # The force is not accumulated on link 0 when the model is fixed-base.
        f = jax.lax.cond(
            pred=jnp.logical_or(λ[i] != 0, model.floating_base()),
            true_fun=update_f,
            false_fun=lambda f: f,
            operand=f,
        )

        return (τ, f), None

    # The links are visited in reverse order, from the last one down to link 1.
    (τ, f), _ = (
        jax.lax.scan(
            f=backward_pass,
            init=backward_pass_carry,
            xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
        )
        if model.number_of_links() > 1
        else [(τ, f), None]
    )

    # ==============
    # Adjust outputs
    # ==============

    # Express the base 6D force in the world frame.
    W_f0 = B_X_W.T @ f[0]

    # Return flat vectors; the joint forces are kept 1D also for single-DoF models.
    return W_f0.squeeze(), jnp.atleast_1d(τ.squeeze())
================================================
FILE: src/jaxsim/rbda/utils.py
================================================
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import exceptions
from jaxsim.math import STANDARD_GRAVITY
def process_inputs(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike | None = None,
    base_quaternion: jtp.VectorLike | None = None,
    joint_positions: jtp.VectorLike | None = None,
    base_linear_velocity: jtp.VectorLike | None = None,
    base_angular_velocity: jtp.VectorLike | None = None,
    joint_velocities: jtp.VectorLike | None = None,
    base_linear_acceleration: jtp.VectorLike | None = None,
    base_angular_acceleration: jtp.VectorLike | None = None,
    joint_accelerations: jtp.VectorLike | None = None,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    standard_gravity: jtp.ScalarLike | None = None,
) -> tuple[
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Matrix,
    jtp.Vector,
]:
    """
    Adjust the inputs to rigid-body dynamics algorithms.

    Missing inputs are replaced with neutral defaults (zeros, the identity
    quaternion, the standard gravity), all quantities are squeezed to their
    canonical shape, validated, and cast to float.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity: The linear velocity of the base link.
        base_angular_velocity: The angular velocity of the base link.
        joint_velocities: The velocities of the joints.
        base_linear_acceleration: The linear acceleration of the base link.
        base_angular_acceleration: The angular acceleration of the base link.
        joint_accelerations: The accelerations of the joints.
        joint_forces: The forces applied to the joints.
        link_forces: The forces applied to the links.
        standard_gravity: The standard gravity constant.

    Returns:
        The adjusted inputs, in this order: base position, base quaternion, joint
        positions, 6D base velocity, joint velocities, 6D base acceleration, joint
        accelerations, joint forces, link forces, and 6D gravity acceleration.

    Raises:
        ValueError: If any input does not have the expected shape after squeezing.
    """

    dofs = model.dofs()
    nl = model.number_of_links()

    def squeezed(value: jtp.ArrayLike) -> jtp.Array:
        # Convert before squeezing: the `*Like` annotations admit tuples, lists,
        # and Python scalars, none of which expose a `.squeeze()` method.
        return jnp.asarray(value).squeeze()

    # Floating-base position.
    W_p_B = base_position
    W_Q_B = base_quaternion
    s = joint_positions

    # Floating-base velocity in inertial-fixed representation.
    W_vl_WB = base_linear_velocity
    W_ω_WB = base_angular_velocity
    ṡ = joint_velocities

    # Floating-base acceleration in inertial-fixed representation.
    W_v̇l_WB = base_linear_acceleration
    W_ω̇_WB = base_angular_acceleration
    s̈ = joint_accelerations

    # System dynamics inputs.
    f = link_forces
    τ = joint_forces

    # Fill missing data and adjust dimensions.
    s = jnp.atleast_1d(squeezed(s)) if s is not None else jnp.zeros(dofs)
    ṡ = jnp.atleast_1d(squeezed(ṡ)) if ṡ is not None else jnp.zeros(dofs)
    s̈ = jnp.atleast_1d(squeezed(s̈)) if s̈ is not None else jnp.zeros(dofs)
    τ = jnp.atleast_1d(squeezed(τ)) if τ is not None else jnp.zeros(dofs)
    W_vl_WB = jnp.atleast_1d(squeezed(W_vl_WB)) if W_vl_WB is not None else jnp.zeros(3)
    W_v̇l_WB = (
        jnp.atleast_1d(squeezed(W_v̇l_WB)) if W_v̇l_WB is not None else jnp.zeros(3)
    )
    W_p_B = jnp.atleast_1d(squeezed(W_p_B)) if W_p_B is not None else jnp.zeros(3)
    W_ω_WB = jnp.atleast_1d(squeezed(W_ω_WB)) if W_ω_WB is not None else jnp.zeros(3)
    W_ω̇_WB = jnp.atleast_1d(squeezed(W_ω̇_WB)) if W_ω̇_WB is not None else jnp.zeros(3)

    # Squeezing collapses the link forces of a single-link model to 1D,
    # therefore the leading axis is restored with `atleast_2d`.
    f = jnp.atleast_2d(squeezed(f)) if f is not None else jnp.zeros(shape=(nl, 6))

    W_Q_B = (
        jnp.atleast_1d(squeezed(W_Q_B))
        if W_Q_B is not None
        else jnp.array([1.0, 0, 0, 0])
    )
    standard_gravity = (
        jnp.array(standard_gravity).squeeze()
        if standard_gravity is not None
        else STANDARD_GRAVITY
    )

    # Validate the shapes. These checks run on static shapes, hence they are
    # plain Python exceptions.
    if s.shape != (dofs,):
        raise ValueError(s.shape, dofs)

    if ṡ.shape != (dofs,):
        raise ValueError(ṡ.shape, dofs)

    if s̈.shape != (dofs,):
        raise ValueError(s̈.shape, dofs)

    if τ.shape != (dofs,):
        raise ValueError(τ.shape, dofs)

    if W_p_B.shape != (3,):
        raise ValueError(W_p_B.shape, (3,))

    if W_vl_WB.shape != (3,):
        raise ValueError(W_vl_WB.shape, (3,))

    if W_ω_WB.shape != (3,):
        raise ValueError(W_ω_WB.shape, (3,))

    if W_v̇l_WB.shape != (3,):
        raise ValueError(W_v̇l_WB.shape, (3,))

    if W_ω̇_WB.shape != (3,):
        raise ValueError(W_ω̇_WB.shape, (3,))

    if f.shape != (nl, 6):
        raise ValueError(f.shape, (nl, 6))

    if W_Q_B.shape != (4,):
        raise ValueError(W_Q_B.shape, (4,))

    # Check that the quaternion does not contain NaN values.
    exceptions.raise_value_error_if(
        condition=jnp.isnan(W_Q_B).any(),
        msg="A RBDA received a quaternion that contains NaN values.",
    )

    # Check that the quaternion is unary since our RBDAs make this assumption in order
    # to prevent introducing additional normalizations that would affect AD.
    exceptions.raise_value_error_if(
        condition=~jnp.allclose(W_Q_B.dot(W_Q_B), 1.0),
        msg="A RBDA received a quaternion that is not normalized.",
    )

    # Pack the 6D base velocity and acceleration.
    W_v_WB = jnp.hstack([W_vl_WB, W_ω_WB])
    W_v̇_WB = jnp.hstack([W_v̇l_WB, W_ω̇_WB])

    # Create the 6D gravity acceleration.
    W_g = jnp.array([0, 0, standard_gravity, 0, 0, 0])

    return (
        W_p_B.astype(float),
        W_Q_B.astype(float),
        s.astype(float),
        W_v_WB.astype(float),
        ṡ.astype(float),
        W_v̇_WB.astype(float),
        s̈.astype(float),
        τ.astype(float),
        f.astype(float),
        W_g.astype(float),
    )
================================================
FILE: src/jaxsim/terrain/__init__.py
================================================
from . import terrain
from .terrain import FlatTerrain, PlaneTerrain, Terrain
================================================
FILE: src/jaxsim/terrain/terrain.py
================================================
from __future__ import annotations
import abc
import dataclasses
import jax.numpy as jnp
import jax_dataclasses
import numpy as np
import jaxsim.math
import jaxsim.typing as jtp
from jaxsim import exceptions
class Terrain(abc.ABC):
    """
    Abstract interface implemented by all terrain models.

    Attributes:
        delta: The step used to approximate the terrain slope numerically.
    """

    delta = 0.010

    @abc.abstractmethod
    def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
        """
        Evaluate the terrain height over the point (x, y).

        Args:
            x: The x-coordinate of the query point.
            y: The y-coordinate of the query point.

        Returns:
            The terrain height at the query point.
        """

        ...

    def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
        """
        Evaluate the unit normal of the terrain surface over the point (x, y).

        The slope is approximated with central differences of `height` using
        a step of `delta` along each horizontal axis.

        Args:
            x: The x-coordinate of the query point.
            y: The y-coordinate of the query point.

        Returns:
            The normalized normal vector of the surface at the query point.
        """

        # https://stackoverflow.com/a/5282364
        height_x_ahead = self.height(x=x + self.delta, y=y)
        height_x_behind = self.height(x=x - self.delta, y=y)
        height_y_ahead = self.height(x=x, y=y + self.delta)
        height_y_behind = self.height(x=x, y=y - self.delta)

        # Horizontal components of the (not yet normalized) normal.
        component_x = (height_x_behind - height_x_ahead) / (2 * self.delta)
        component_y = (height_y_behind - height_y_ahead) / (2 * self.delta)

        direction = jnp.array([component_x, component_y, 1.0])

        return direction / jaxsim.math.safe_norm(direction, axis=-1)
@jax_dataclasses.pytree_dataclass
class FlatTerrain(Terrain):
    """
    Horizontal terrain whose surface lies at a constant height.
    """

    _height: float = dataclasses.field(default=0.0, kw_only=True)

    @staticmethod
    def build(height: jtp.FloatLike = 0.0) -> FlatTerrain:
        """
        Build a flat terrain located at the given height.

        Args:
            height: The constant height of the terrain surface.

        Returns:
            FlatTerrain: The terrain instance.
        """

        terrain_height = float(height)

        return FlatTerrain(_height=terrain_height)

    def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
        """
        Evaluate the terrain height over the point (x, y).

        Args:
            x: The x-coordinate of the query point (unused, the terrain is flat).
            y: The y-coordinate of the query point (unused, the terrain is flat).

        Returns:
            The constant height of the terrain as a float array.
        """

        constant_height = self._height

        return jnp.array(constant_height, dtype=float)

    def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
        """
        Evaluate the unit normal of the terrain surface over the point (x, y).

        Args:
            x: The x-coordinate of the query point (unused, the terrain is flat).
            y: The y-coordinate of the query point (unused, the terrain is flat).

        Returns:
            The vertical unit vector.
        """

        vertical_axis = [0.0, 0.0, 1.0]

        return jnp.array(vertical_axis, dtype=float)

    def __hash__(self) -> int:

        # The height is the only attribute characterizing a flat terrain.
        return hash(self._height)

    def __eq__(self, other: FlatTerrain) -> bool:

        # Objects of unrelated types never compare equal; otherwise compare heights.
        return isinstance(other, FlatTerrain) and self._height == other._height
@jax_dataclasses.pytree_dataclass
class PlaneTerrain(FlatTerrain):
    """
    Represents a terrain model with a flat surface defined by a normal vector.
    """

    _normal: tuple[float, float, float] = jax_dataclasses.field(
        default=(0.0, 0.0, 1.0), kw_only=True
    )

    @staticmethod
    def build(height: jtp.FloatLike = 0.0, *, normal: jtp.VectorLike) -> PlaneTerrain:
        """
        Create a PlaneTerrain instance with a specified plane normal vector.

        Args:
            height: The height of the plane over the origin.
            normal: The normal vector of the terrain plane. It does not need to be
                a unit vector since it is normalized before being stored.

        Returns:
            PlaneTerrain: A PlaneTerrain instance.

        Raises:
            ValueError:
                If the normal is not a 3D vector, or if its norm is zero or not
                finite and therefore it cannot be normalized.
        """

        normal = jnp.array(normal, dtype=float)
        height = jnp.array(height, dtype=float)

        if normal.shape != (3,):
            msg = "Expected a 3D vector for the plane normal, got '{}'."
            raise ValueError(msg.format(normal.shape))

        # A null or non-finite vector does not define a direction. Dividing by its
        # norm would silently store NaN values in the terrain, hence fail early.
        norm = jnp.linalg.norm(normal)

        if not 0.0 < float(norm) < float("inf"):
            msg = "Expected a finite non-zero vector for the plane normal, got '{}'."
            raise ValueError(msg.format(normal.tolist()))

        # Make sure that the plane normal is a unit vector.
        normal = normal / norm

        # Store plain Python values so that the terrain attributes are hashable.
        return PlaneTerrain(
            _height=height.item(),
            _normal=tuple(normal.tolist()),
        )

    def normal(
        self, x: jtp.FloatLike | None = None, y: jtp.FloatLike | None = None
    ) -> jtp.Vector:
        """
        Compute the normal vector of the terrain at a specific (x, y) location.

        Args:
            x: The x-coordinate of the location (unused, the normal is constant).
            y: The y-coordinate of the location (unused, the normal is constant).

        Returns:
            The normal vector of the terrain surface at the specified location.
        """

        return jnp.array(self._normal, dtype=float)

    def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
        """
        Compute the height of the terrain at a specific (x, y) location on a plane.

        Args:
            x: The x-coordinate of the location.
            y: The y-coordinate of the location.

        Returns:
            The height of the terrain at the specified location on the plane.
        """

        # Equation of the plane:      A x + B y + C z + D = 0
        # Normal vector coordinates:  (A, B, C)
        # The height over the origin: -D/C

        # Get the plane equation coefficients from the terrain normal.
        A, B, C = self._normal

        # A vertical plane has no well-defined height over the (x, y) point.
        exceptions.raise_value_error_if(
            condition=jnp.allclose(C, 0.0),
            msg="The z component of the normal cannot be zero.",
        )

        # Compute the final coefficient D considering the terrain height.
        D = -C * self._height

        # Invert the plane equation to get the height at the given (x, y) coordinates.
        return jnp.array(-(A * x + B * y + D) / C).astype(float)

    def __hash__(self) -> int:

        # Imported locally, as in the original implementation of this method.
        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                hash(self._height),
                HashedNumpyArray.hash_of_array(
                    array=np.array(self._normal, dtype=float)
                ),
            )
        )

    def __eq__(self, other: PlaneTerrain) -> bool:

        if not isinstance(other, PlaneTerrain):
            return False

        # Compare with a tolerance both the height and the normal of the plane.
        if not (
            np.allclose(self._height, other._height)
            and np.allclose(
                np.array(self._normal, dtype=float),
                np.array(other._normal, dtype=float),
            )
        ):
            return False

        return True
================================================
FILE: src/jaxsim/typing.py
================================================
from collections.abc import Hashable
from typing import Any, TypeVar
import jax
# =========
# JAX types
# =========

# All the following names are aliases of `jax.Array`. They do not encode any
# shape or dtype information and only document the intent of the annotated object.
Array = jax.Array
Scalar = Array
Vector = Array
Matrix = Array

Int = Scalar
Bool = Scalar
Float = Scalar

# Loose description of a PyTree. The union includes `Any`, therefore this alias
# documents the intent rather than constraining the annotated object.
PyTree: object = (
    dict[Hashable, TypeVar("PyTree")]
    | list[TypeVar("PyTree")]
    | tuple[TypeVar("PyTree")]
    | jax.Array
    | Any
    | None
)

# =======================
# Mixed JAX / NumPy types
# =======================

# The `*Like` aliases widen the corresponding JAX type with the objects accepted
# by `jax.typing.ArrayLike` and, where listed, with plain Python types.
ArrayLike = jax.typing.ArrayLike | tuple
ScalarLike = int | float | Scalar | ArrayLike
VectorLike = Vector | ArrayLike | tuple
MatrixLike = Matrix | ArrayLike

IntLike = int | Int | jax.typing.ArrayLike
BoolLike = bool | Bool | jax.typing.ArrayLike
FloatLike = float | Float | jax.typing.ArrayLike
================================================
FILE: src/jaxsim/utils/__init__.py
================================================
from jax_dataclasses._copy_and_mutate import _Mutability as Mutability
from .jaxsim_dataclass import JaxsimDataclass
from .tracing import not_tracing, tracing
from .wrappers import HashedNumpyArray, HashlessObject
================================================
FILE: src/jaxsim/utils/jaxsim_dataclass.py
================================================
import abc
import contextlib
import dataclasses
import functools
from collections.abc import Callable, Iterator, Sequence
from typing import Any, ClassVar
import jax.flatten_util
import jax_dataclasses
import jaxsim.typing as jtp
from . import Mutability
try:
from typing import Self
except ImportError:
from typing_extensions import Self
@jax_dataclasses.pytree_dataclass
class JaxsimDataclass(abc.ABC):
    """Class extending `jax_dataclasses.pytree_dataclass` instances with utilities."""

    # This attribute is set by jax_dataclasses
    __mutability__: ClassVar[Mutability] = Mutability.FROZEN

    @contextlib.contextmanager
    def editable(self: Self, validate: bool = True) -> Iterator[Self]:
        """
        Context manager to operate on a mutable copy of the object.

        Args:
            validate: Whether to validate the output PyTree upon exiting the context.

        Yields:
            A mutable copy of the object.

        Note:
            This context manager is useful to operate on an r/w copy of a PyTree making
            sure that the output object does not trigger JIT recompilations.
        """

        mutability = (
            Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
        )

        # The original object is never touched, all the edits land on the copy.
        with self.copy().mutable_context(mutability=mutability) as obj:
            yield obj

    @contextlib.contextmanager
    def mutable_context(
        self: Self,
        mutability: Mutability = Mutability.MUTABLE,
        restore_after_exception: bool = True,
    ) -> Iterator[Self]:
        """
        Context manager to temporarily change the mutability of the object.

        Args:
            mutability: The mutability to set.
            restore_after_exception:
                Whether to restore the original object in case of an exception
                occurring within the context.

        Yields:
            The object with the new mutability.

        Raises:
            ValueError:
                If validation is active and the PyTree structure, or the shapes,
                dtypes, or weak types of its leaves changed within the context.

        Note:
            This context manager is useful to operate in place on a PyTree without
            the need to make a copy while optionally keeping active the checks on
            the PyTree structure, shapes, and dtypes.
        """

        # Keep a snapshot of the fields only if it could be needed later.
        if restore_after_exception:
            self_copy = self.copy()

        original_mutability = self.mutability()

        # Record the properties of the PyTree that must be preserved.
        original_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self)
        original_shapes = JaxsimDataclass.get_leaf_shapes(tree=self)
        original_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self)
        original_structure = jax.tree.structure(tree=self)

        def restore_self() -> None:
            # Disable the validation so that the snapshot can be written back
            # regardless of the state in which the object was left.
            self.set_mutability(mutability=Mutability.MUTABLE_NO_VALIDATION)
            for f in dataclasses.fields(self_copy):
                setattr(self, f.name, getattr(self_copy, f.name))

        try:
            self.set_mutability(mutability=mutability)
            yield self

            # The checks run only if the caller's block completed successfully.
            if mutability is not Mutability.MUTABLE_NO_VALIDATION:
                new_structure = jax.tree.structure(tree=self)
                if original_structure != new_structure:
                    msg = "Pytree structure has changed from {} to {}"
                    raise ValueError(msg.format(original_structure, new_structure))

                new_shapes = JaxsimDataclass.get_leaf_shapes(tree=self)
                if original_shapes != new_shapes:
                    msg = "Leaves shapes have changed from {} to {}"
                    raise ValueError(msg.format(original_shapes, new_shapes))

                new_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self)
                if original_dtypes != new_dtypes:
                    msg = "Leaves dtypes have changed from {} to {}"
                    raise ValueError(msg.format(original_dtypes, new_dtypes))

                new_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self)
                if original_weak_types != new_weak_types:
                    msg = "Leaves weak types have changed from {} to {}"
                    raise ValueError(msg.format(original_weak_types, new_weak_types))

        except Exception:
            # This branch handles both the exceptions raised by the caller's block
            # and the validation errors raised above.
            if restore_after_exception:
                restore_self()
            raise

        finally:
            # Always executed, hence the mutability does not need to be restored
            # explicitly before re-raising.
            self.set_mutability(original_mutability)

    @staticmethod
    def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None, ...]:
        """
        Get the leaf shapes of a PyTree.

        Args:
            tree: The PyTree to consider.

        Returns:
            A tuple containing the leaf shapes of the PyTree or `None` if the leaf is
            not a numpy-like array.
        """

        return tuple(getattr(leaf, "shape", None) for leaf in jax.tree.leaves(tree))

    @staticmethod
    def get_leaf_dtypes(tree: jtp.PyTree) -> tuple:
        """
        Get the leaf dtypes of a PyTree.

        Args:
            tree: The PyTree to consider.

        Returns:
            A tuple containing the leaf dtypes of the PyTree or `None` if the leaf is
            not a numpy-like array.
        """

        return tuple(getattr(leaf, "dtype", None) for leaf in jax.tree.leaves(tree))

    @staticmethod
    def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool | None, ...]:
        """
        Get the leaf weak types of a PyTree.

        Args:
            tree: The PyTree to consider.

        Returns:
            A tuple containing the `weak_type` attribute of each leaf, or `None`
            for the leaves that do not expose it.
        """

        return tuple(
            getattr(leaf, "weak_type", None) for leaf in jax.tree.leaves(tree)
        )

    @staticmethod
    def check_compatibility(*trees: Any) -> None:
        """
        Check whether the PyTrees are compatible in structure, shape, and dtype.

        Args:
            *trees: The PyTrees to compare. The first one is the reference against
                which all the others are checked.

        Raises:
            ValueError: If the PyTrees have incompatible structures, shapes, or dtypes.
        """

        target_structure = jax.tree.structure(trees[0])

        compatible_structure = all(
            jax.tree.structure(tree) == target_structure for tree in trees[1:]
        )

        if not compatible_structure:
            raise ValueError(
                f"Pytrees have incompatible structures.\n"
                f"Original: {', '.join(map(str, [jax.tree.structure(tree) for tree in trees[1:]]))}\n"
                f"Target: {target_structure}"
            )

        target_shapes = JaxsimDataclass.get_leaf_shapes(trees[0])

        compatible_shapes = all(
            JaxsimDataclass.get_leaf_shapes(tree) == target_shapes
            for tree in trees[1:]
        )

        if not compatible_shapes:
            raise ValueError("Pytrees have incompatible shapes.")

        target_dtypes = JaxsimDataclass.get_leaf_dtypes(trees[0])

        compatible_dtypes = all(
            JaxsimDataclass.get_leaf_dtypes(tree) == target_dtypes
            for tree in trees[1:]
        )

        if not compatible_dtypes:
            raise ValueError("Pytrees have incompatible dtypes.")

    def is_mutable(self, validate: bool = False) -> bool:
        """
        Check whether the object is mutable.

        Args:
            validate: Additionally checks if the object also has validation enabled.

        Returns:
            True if the object is mutable (and, when `validate` is True, if it also
            has validation enabled), False otherwise.
        """

        if validate:
            return self.__mutability__ is Mutability.MUTABLE

        # Without the additional check, both mutable states qualify. Reporting an
        # object with validation enabled as not mutable would be incorrect.
        return self.__mutability__ in (
            Mutability.MUTABLE,
            Mutability.MUTABLE_NO_VALIDATION,
        )

    def mutability(self) -> Mutability:
        """
        Get the mutability type of the object.

        Returns:
            The mutability type of the object.
        """

        return self.__mutability__

    def set_mutability(self, mutability: Mutability) -> None:
        """
        Set the mutability of the object in-place.

        Args:
            mutability: The desired mutability type.
        """

        # NOTE(review): this relies on a private helper of jax_dataclasses.
        jax_dataclasses._copy_and_mutate._mark_mutable(
            self, mutable=mutability, visited=set()
        )

    def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self:
        """
        Return a mutable reference of the object.

        Args:
            mutable: Whether to make the object mutable.
            validate: Whether to enable validation on the object.

        Returns:
            A mutable reference of the object.

        Note:
            The mutability is changed in place and the same object is returned.
        """

        if mutable:
            mutability = (
                Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
            )
        else:
            mutability = Mutability.FROZEN

        self.set_mutability(mutability=mutability)
        return self

    def copy(self: Self) -> Self:
        """
        Return a copy of the object.

        Returns:
            A copy of the object.
        """

        # Make a copy calling tree_map.
        obj = jax.tree.map(lambda leaf: leaf, self)

        # Make sure that the copied object and all the copied leaves have the same
        # mutability of the original object.
        obj.set_mutability(mutability=self.mutability())

        return obj

    def replace(self: Self, validate: bool = True, **kwargs) -> Self:
        """
        Return a new object with the specified fields replaced by new values.

        Args:
            validate: Whether to validate that the new fields do not alter the PyTree.
            **kwargs: The fields to replace.

        Returns:
            A new object with the specified fields replaced, having the same
            mutability of the original object.

        Raises:
            ValueError:
                If validation is enabled and the new object is not compatible in
                structure, shapes, or dtypes with the original one.
        """

        # Use the dataclasses replace method.
        obj = dataclasses.replace(self, **kwargs)

        if validate:
            JaxsimDataclass.check_compatibility(self, obj)

        # Make sure that all the new leaves have the same mutability of the object.
        obj.set_mutability(mutability=self.mutability())

        return obj

    def flatten(self) -> jtp.Vector:
        """
        Flatten the object into a 1D vector.

        Returns:
            A 1D vector containing the flattened object.
        """

        return self.flatten_fn()(self)

    @classmethod
    def flatten_fn(cls: type[Self]) -> Callable[[Self], jtp.Vector]:
        """
        Return a function to flatten the object into a 1D vector.

        Returns:
            A function to flatten the object into a 1D vector.
        """

        return lambda pytree: jax.flatten_util.ravel_pytree(pytree)[0]

    def unflatten_fn(self: Self) -> Callable[[jtp.Vector], Self]:
        """
        Return a function to unflatten a 1D vector into the object.

        Returns:
            A function to unflatten a 1D vector into the object.

        Notes:
            Due to JAX internals, the function to unflatten a PyTree needs to be
            created from an existing instance of the PyTree.
        """

        return jax.flatten_util.ravel_pytree(self)[1]
================================================
FILE: src/jaxsim/utils/tracing.py
================================================
from typing import Any
import jax._src.core
import jax.flatten_util
import jax.interpreters.partial_eval
def tracing(var: Any) -> bool | jax.Array:
"""Return True if the variable is being traced by JAX, False otherwise."""
return isinstance(
var, jax._src.core.Tracer | jax.interpreters.partial_eval.DynamicJaxprTracer
)
def not_tracing(var: Any) -> bool | jax.Array:
    """Tell whether the given variable is a concrete value, i.e. not traced by JAX."""

    is_traced = tracing(var)

    return is_traced is False
================================================
FILE: src/jaxsim/utils/wrappers.py
================================================
from __future__ import annotations
import dataclasses
from collections.abc import Callable
from typing import Generic, TypeVar
import jax
import jax_dataclasses
import numpy as np
import numpy.typing as npt
T = TypeVar("T")
@dataclasses.dataclass
class HashlessObject(Generic[T]):
    """
    A class that wraps an object and makes it hashless.

    This is useful for creating particular JAX pytrees.
    For example, to create a pytree with a static leaf that is ignored
    by JAX when it compares two instances to trigger a JIT recompilation.
    """

    obj: T

    def get(self: HashlessObject[T]) -> T:
        """
        Get the wrapped object.
        """

        return self.obj

    def __hash__(self) -> int:

        # Constant hash, regardless of the wrapped object.
        return 0

    def __eq__(self, other: HashlessObject[T]) -> bool:
        """
        Compare the wrapper with another object.

        Two wrappers always compare equal, whatever they wrap, since their hash is
        constant. Any object that is not a `HashlessObject` compares different.
        """

        # Reject foreign objects before touching them: they are not guaranteed to
        # expose a `get` method, and calling it would raise instead of returning.
        if not isinstance(other, HashlessObject):
            return False

        return hash(self) == hash(other)
@dataclasses.dataclass
class CustomHashedObject(Generic[T]):
    """
    A class that wraps an object and computes its hash with a custom hash function.
    """

    obj: T

    # Function computing the hash of the wrapped object, the built-in one by default.
    hash_function: Callable[[T], int] = hash

    def get(self: CustomHashedObject[T]) -> T:
        """
        Get the wrapped object.
        """

        return self.obj

    def __hash__(self) -> int:

        return self.hash_function(self.obj)

    def __eq__(self, other: CustomHashedObject[T]) -> bool:
        """
        Compare the wrapper with another object.

        Two wrappers compare equal if their hashes match. Any object that is not
        a `CustomHashedObject` compares different.
        """

        # Reject foreign objects before touching them: they are not guaranteed to
        # expose a `get` method, and calling it would raise instead of returning.
        if not isinstance(other, CustomHashedObject):
            return False

        return hash(self) == hash(other)
@jax_dataclasses.pytree_dataclass
class HashedNumpyArray:
    """
    A class that wraps a numpy array and makes it hashable.

    This is useful for creating particular JAX pytrees.
    For example, to create a pytree with a plain NumPy or JAX NumPy array as static leaf.

    Note:
        Calculating with the wrapper class the hash of a very large array can be
        very expensive. If the array is large and only the equality operator is needed,
        set `large_array=True` to use a faster comparison method.
    """

    array: jax.Array | npt.NDArray

    # Resolution used when hashing (and as `atol` when `large_array=True`).
    precision: float | None = dataclasses.field(
        default=1e-9, repr=False, compare=False, hash=False
    )

    # When True, `__eq__` uses `np.allclose` instead of comparing hashes.
    large_array: jax_dataclasses.Static[bool] = dataclasses.field(
        default=False, repr=False, compare=False, hash=False
    )

    def get(self) -> jax.Array | npt.NDArray:
        """
        Get the wrapped array.

        Returns:
            The wrapped array.
        """

        return self.array

    def __hash__(self) -> int:

        return HashedNumpyArray.hash_of_array(
            array=self.array, precision=self.precision
        )

    def __eq__(self, other: HashedNumpyArray) -> bool:

        if not isinstance(other, HashedNumpyArray):
            return False

        if self.large_array:
            # Element-wise comparison, avoiding the cost of hashing both arrays.
            return np.allclose(
                self.array,
                other.array,
                **(dict(atol=self.precision) if self.precision is not None else {}),
            )

        return hash(self) == hash(other)

    @staticmethod
    def hash_of_array(
        array: jax.Array | npt.NDArray, precision: float | None = 1e-9
    ) -> int:
        """
        Calculate the hash of a NumPy array.

        Args:
            array: The array to hash.
            precision: Optionally limit the precision over which the hash is computed.

        Returns:
            The hash of the array.
        """

        array = np.array(array).flatten()

        # NaN never compares equal to anything (itself included), so the NaN
        # entries must be located with `np.isnan`. Non-floating dtypes cannot
        # hold NaN, hence the all-False mask.
        nan_mask = (
            np.isnan(array)
            if np.issubdtype(array.dtype, np.inexact)
            else np.zeros(array.shape, dtype=bool)
        )

        # Replace the non-finite entries with finite stand-ins so that the
        # integer casts below are well defined.
        array = np.where(nan_mask, hash(np.nan), array)
        array = np.where(array == np.inf, hash(np.inf), array)
        array = np.where(array == -np.inf, hash(-np.inf), array)

        if precision is not None:

            # Split each entry in three integer chunks:
            # - the multiples of 1/precision,
            # - the remaining integer part,
            # - the fractional part expressed in units of `precision`.
            integer1 = (array * precision).astype(int)

            # Use the actual precision for the high-order chunk in both
            # subtractions (it was previously hard-coded to 1e9 in one of them).
            high_order = integer1 / precision

            integer2 = (array - high_order).astype(int)

            decimal_array = ((array - high_order - integer2) / precision).astype(int)

            array = np.hstack([integer1, integer2, decimal_array]).astype(int)

        return hash(tuple(array.tolist()))
================================================
FILE: tests/__init__.py
================================================
================================================
FILE: tests/assets/4_bar_opened.urdf
================================================
================================================
FILE: tests/assets/cube.stl
================================================
solid model
facet normal 0.0 0.0 -1.0
outer loop
vertex 20.0 0.0 0.0
vertex 0.0 -20.0 0.0
vertex 0.0 0.0 0.0
endloop
endfacet
facet normal 0.0 0.0 -1.0
outer loop
vertex 0.0 -20.0 0.0
vertex 20.0 0.0 0.0
vertex 20.0 -20.0 0.0
endloop
endfacet
facet normal -0.0 -1.0 -0.0
outer loop
vertex 20.0 -20.0 20.0
vertex 0.0 -20.0 0.0
vertex 20.0 -20.0 0.0
endloop
endfacet
facet normal -0.0 -1.0 -0.0
outer loop
vertex 0.0 -20.0 0.0
vertex 20.0 -20.0 20.0
vertex 0.0 -20.0 20.0
endloop
endfacet
facet normal 1.0 0.0 0.0
outer loop
vertex 20.0 0.0 0.0
vertex 20.0 -20.0 20.0
vertex 20.0 -20.0 0.0
endloop
endfacet
facet normal 1.0 0.0 0.0
outer loop
vertex 20.0 -20.0 20.0
vertex 20.0 0.0 0.0
vertex 20.0 0.0 20.0
endloop
endfacet
facet normal -0.0 -0.0 1.0
outer loop
vertex 20.0 -20.0 20.0
vertex 0.0 0.0 20.0
vertex 0.0 -20.0 20.0
endloop
endfacet
facet normal -0.0 -0.0 1.0
outer loop
vertex 0.0 0.0 20.0
vertex 20.0 -20.0 20.0
vertex 20.0 0.0 20.0
endloop
endfacet
facet normal -1.0 0.0 0.0
outer loop
vertex 0.0 0.0 20.0
vertex 0.0 -20.0 0.0
vertex 0.0 -20.0 20.0
endloop
endfacet
facet normal -1.0 0.0 0.0
outer loop
vertex 0.0 -20.0 0.0
vertex 0.0 0.0 20.0
vertex 0.0 0.0 0.0
endloop
endfacet
facet normal -0.0 1.0 0.0
outer loop
vertex 0.0 0.0 20.0
vertex 20.0 0.0 0.0
vertex 0.0 0.0 0.0
endloop
endfacet
facet normal -0.0 1.0 0.0
outer loop
vertex 20.0 0.0 0.0
vertex 0.0 0.0 20.0
vertex 20.0 0.0 20.0
endloop
endfacet
endsolid model
================================================
FILE: tests/assets/double_pendulum.sdf
================================================
worldbase_link1 0 0-551001000.000.00 0 0 0 0 01001001010 0 1 0 0 00.20 0.20 2.150 0 1 0 0 00.20 0.20 2.150.20 0 2 -3.1415 0 0base_linkright_link1 0 0-1001001001001.000.00 0 0 0 0 000 0 0.5 0 0 011.0001.001.00 0 0.5 0 0 00.20 0.20 1.0-0.20 0 2 -3.1415 0 0base_linkleft_link1 0 0-1001001001001.000.00 0 0 0 0 000.0 0 0.5 0 0 011.0001.001.00.0 0 0.5 0 0 00.20 0.20 1.00.20 0 1 0 0 0-0.20 0 1 0 0 0 -0.2 0 1 3.14 0 0 0.2 0 1 3.14 0 0
================================================
FILE: tests/assets/mixed_shapes_robot.urdf
================================================
================================================
FILE: tests/assets/test_cube.urdf
================================================
================================================
FILE: tests/conftest.py
================================================
import os
os.environ["JAXSIM_ENABLE_EXCEPTIONS"] = "1"
import pathlib
import subprocess
import jax
import numpy as np
import pytest
import rod
import rod.urdf.exporter
import jaxsim
import jaxsim.api as js
from jaxsim.api.model import IntegratorType
def pytest_addoption(parser):
    """
    Register the custom command line options of the test suite.

    Args:
        parser: The parser object on which `addoption` is called.
    """

    option_specs = (
        (
            "--gpu-only",
            dict(
                action="store_true",
                default=False,
                help="Run tests only if GPU is available and utilized",
            ),
        ),
        (
            "--batch-size",
            dict(
                action="store",
                default="None",
                help="Batch size for vectorized benchmarks (only applies to benchmark tests)",
            ),
        ),
    )

    for flag, settings in option_specs:
        parser.addoption(flag, **settings)
def pytest_generate_tests(metafunc):
    """
    Parametrize `batch_size` when the `--batch-size` option is provided.

    Args:
        metafunc: The object exposing `fixturenames`, `config` and `parametrize`.
    """

    # Nothing to do for tests that do not request the batch size.
    if "batch_size" not in metafunc.fixturenames:
        return

    requested = metafunc.config.getoption("--batch-size")

    # The string "None" is the default value of the option.
    if requested == "None":
        return

    metafunc.parametrize("batch_size", [1, int(requested)])
def check_gpu_usage():
    """
    Run a small JAX computation and check with `nvidia-smi` that GPU memory is used.

    The test session is terminated through `pytest.exit` if `nvidia-smi` cannot
    be executed, if it returns a non-zero exit code, or if all the reported
    GPUs show zero used memory.
    """

    # Set environment variable to prioritize GPU.
    os.environ["JAX_PLATFORM_NAME"] = "gpu"

    # Run a simple JAX operation
    x = jax.device_put(jax.numpy.ones((512, 512)))
    y = jax.device_put(jax.numpy.ones((512, 512)))
    _ = jax.numpy.dot(x, y).block_until_ready()

    # Check GPU memory usage with nvidia-smi.
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader"],
            capture_output=True,
            text=True,
        )
    except OSError:
        # Raised (e.g. as FileNotFoundError) when the executable is missing or
        # cannot be started. Handle it like a failed query.
        result = None

    if result is None or result.returncode != 0:
        pytest.exit(
            "Failed to query GPU usage. Ensure nvidia-smi is installed and accessible."
        )

    # Each non-empty line starts with the amount of used memory of one GPU.
    gpu_memory_usage = [
        int(line.split()[0]) for line in result.stdout.splitlines() if line.strip()
    ]

    if all(usage == 0 for usage in gpu_memory_usage):
        pytest.exit(
            "GPU is available but not utilized during computations. Check your JAX installation."
        )
def pytest_configure(config) -> None:
    """
    Pytest configuration hook.

    It seeds the global PRNG key from `JAXSIM_TEST_SEED` (default 0) and, when
    `--gpu-only` is passed, verifies that a GPU is present and utilized.
    """

    # This is a global variable that is updated by the `prng_key` fixture.
    seed = int(os.environ.get("JAXSIM_TEST_SEED", 0))
    pytest.prng_key = jax.random.PRNGKey(seed=seed)

    # The GPU checks are performed only on request.
    if not config.getoption("--gpu-only"):
        return

    gpu_found = any(device.platform == "gpu" for device in jax.devices())

    if not gpu_found:
        pytest.exit("No GPU devices found. Check your JAX installation.")

    # Ensure GPU is being used during computation
    check_gpu_usage()
def load_model_from_file(file_path: pathlib.Path, is_urdf=False) -> rod.Sdf:
    """
    Read a model description from disk with `rod.Sdf.load`.

    Args:
        file_path: The path to the model file.
        is_urdf: Whether the file is in URDF or SDF format.

    Returns:
        The corresponding rod model.
    """

    loaded_sdf = rod.Sdf.load(file_path, is_urdf=is_urdf)

    return loaded_sdf
# ================
# Generic fixtures
# ================
@pytest.fixture(scope="function")
def prng_key() -> jax.Array:
    """
    Fixture providing a fresh PRNG key to each test function.

    Returns:
        The new PRNG key passed to the test.

    Note:
        The global key stored in `pytest.prng_key` (initialized in the
        `pytest_configure` hook) is split in two: the first half replaces
        the global key, the second half is handed to the test.
    """

    next_global_key, test_key = jax.random.split(pytest.prng_key, num=2)
    pytest.prng_key = next_global_key

    return test_key
@pytest.fixture(
    scope="function",
    params=[
        pytest.param(vel_repr, id=vel_repr_id)
        for vel_repr_id, vel_repr in (
            ("inertial", jaxsim.VelRepr.Inertial),
            ("body", jaxsim.VelRepr.Body),
            ("mixed", jaxsim.VelRepr.Mixed),
        )
    ],
)
def velocity_representation(request) -> jaxsim.VelRepr:
    """
    Parametrized fixture iterating over the inertial, body, and mixed
    velocity representations.

    Returns:
        A velocity representation.
    """

    return request.param
@pytest.fixture(
    scope="function",
    params=[
        pytest.param(IntegratorType.SemiImplicitEuler, id="semi_implicit_euler"),
        pytest.param(IntegratorType.RungeKutta4, id="runge_kutta_4"),
        pytest.param(IntegratorType.RungeKutta4Fast, id="runge_kutta_4_fast"),
    ],
)
def integrator(request) -> IntegratorType:
    """
    Parametrized fixture providing the integrator to use in the simulation.

    The fixture iterates over the `SemiImplicitEuler`, `RungeKutta4`, and
    `RungeKutta4Fast` members of `IntegratorType`.

    Returns:
        The integrator to use in the simulation.
    """

    # The value is one of the `IntegratorType` members listed in `params`.
    return request.param
@pytest.fixture(scope="session")
def batch_size(request) -> int:
    """
    Fixture providing the batch size for vectorized benchmarks.

    Returns:
        The batch size for vectorized benchmarks.

    Note:
        This fixture always returns 1. When the `--batch-size` option is given,
        the `pytest_generate_tests` hook parametrizes `batch_size` with the
        values `[1, <batch size>]`.
    """

    # Default batch size; `request` is not used here.
    return 1
# ================================
# Fixtures providing JaxSim models
# ================================
# All the fixtures in this section must have "session" scope.
# In this way, the models are generated only once and shared among all the tests.
# This is not a fixture.
def build_jaxsim_model(
    model_description: str | pathlib.Path | rod.Model,
) -> js.model.JaxSimModel:
    """
    Create a JaxSim model out of a model description.

    Args:
        model_description: A model description provided by any fixture provider.

    Returns:
        A JaxSim model built from the provided description.
    """

    return js.model.JaxSimModel.build_from_model_description(
        model_description=model_description,
    )
@pytest.fixture(scope="session")
def jaxsim_model_box() -> js.model.JaxSimModel:
    """
    Fixture providing the JaxSim model of a box.

    The box is a single link named `box_link`, with an additional frame
    named `box_frame` attached to it.

    Returns:
        The JaxSim model of a box.
    """

    import rod.builder.primitives
    import rod.urdf.exporter

    # Create on-the-fly a ROD model of a box.
    box_builder = rod.builder.primitives.BoxBuilder(
        x=0.3, y=0.2, z=0.1, mass=1.0, name="box"
    )
    link_builder = box_builder.build_model().add_link(name="box_link")
    rod_model = link_builder.add_inertial().add_visual().add_collision().build()

    # Attach an additional frame to the link of the box.
    box_frame = rod.Frame(
        name="box_frame",
        attached_to="box_link",
        pose=rod.Pose(relative_to="box_link", pose=[1, 1, 1, 0.5, 0.4, 0.3]),
    )
    rod_model.add_frame(box_frame)

    # Export the URDF string.
    exporter = rod.urdf.exporter.UrdfExporter(
        pretty=True, gazebo_preserve_fixed_joints=True
    )
    urdf_string = exporter.to_urdf_string(sdf=rod_model)

    return build_jaxsim_model(model_description=urdf_string)
@pytest.fixture(scope="session")
def jaxsim_model_sphere() -> js.model.JaxSimModel:
    """
    Fixture providing the JaxSim model of a sphere.

    Returns:
        The JaxSim model of a sphere.
    """

    import rod.builder.primitives
    import rod.urdf.exporter

    # Create on-the-fly a ROD model of a sphere.
    sphere_builder = rod.builder.primitives.SphereBuilder(
        radius=0.1, mass=1.0, name="sphere"
    )
    link_builder = sphere_builder.build_model().add_link()
    rod_model = link_builder.add_inertial().add_visual().add_collision().build()

    # Export the URDF string.
    exporter = rod.urdf.exporter.UrdfExporter(pretty=True)
    urdf_string = exporter.to_urdf_string(sdf=rod_model)

    return build_jaxsim_model(model_description=urdf_string)
@pytest.fixture(scope="session")
def ergocub_model_description_path() -> pathlib.Path:
    """
    Fixture providing the path to the URDF model description of the ErgoCub robot.

    Returns:
        The path to the URDF model description of the ErgoCub robot.
    """

    commit_variable = "ROBOT_DESCRIPTION_COMMIT"

    # The variable is set only while the description module is imported,
    # and it is always removed afterwards.
    try:
        os.environ[commit_variable] = "v0.7.7"

        import robot_descriptions.ergocub_description

    finally:
        _ = os.environ.pop(commit_variable, None)

    # Point to the "ergoCubSN001" variant of the description.
    urdf_path = robot_descriptions.ergocub_description.URDF_PATH.replace(
        "ergoCubSN002", "ergoCubSN001"
    )

    return pathlib.Path(urdf_path)
@pytest.fixture(scope="session")
def jaxsim_model_ergocub(
    ergocub_model_description_path: pathlib.Path,
) -> js.model.JaxSimModel:
    """
    Fixture providing the JaxSim model of the ErgoCub robot.

    Args:
        ergocub_model_description_path: The path to the URDF description.

    Returns:
        The JaxSim model of the ErgoCub robot.
    """

    ergocub_model = build_jaxsim_model(
        model_description=ergocub_model_description_path
    )

    return ergocub_model
@pytest.fixture(scope="session")
def jaxsim_model_ergocub_reduced(jaxsim_model_ergocub) -> js.model.JaxSimModel:
    """
    Fixture providing the JaxSim model of the ErgoCub robot with only locomotion joints.

    Returns:
        The JaxSim model of the ErgoCub robot with only locomotion joints.
    """

    # A joint is discarded if its name contains any of these keywords.
    excluded_keywords = (
        "camera",
        # Head and hands.
        "neck",
        "wrist",
        "thumb",
        "index",
        "middle",
        "ring",
        "pinkie",
        # Upper body.
        "torso",
        "elbow",
        "shoulder",
    )

    # Get the names of the joints to keep.
    considered_joints = tuple(
        joint_name
        for joint_name in jaxsim_model_ergocub.joint_names()
        if not any(keyword in joint_name for keyword in excluded_keywords)
    )

    return js.model.reduce(
        model=jaxsim_model_ergocub, considered_joints=considered_joints
    )
@pytest.fixture(scope="session")
def jaxsim_model_ur10() -> js.model.JaxSimModel:
    """
    Fixture providing the JaxSim model of the UR10 robot.

    Returns:
        The JaxSim model of the UR10 robot.
    """

    import robot_descriptions.ur10_description

    ur10_urdf = robot_descriptions.ur10_description.URDF_PATH

    return build_jaxsim_model(model_description=pathlib.Path(ur10_urdf))
@pytest.fixture(scope="session")
def jaxsim_model_single_pendulum() -> js.model.JaxSimModel:
    """
    Fixture providing the JaxSim model of a single pendulum.

    The model has a `base` link fixed to the world and an `upper` link
    connected to it through a continuous joint about the x axis.

    Returns:
        The JaxSim model of a single pendulum.
    """

    import rod.builder.primitives

    build_pose = rod.builder.primitives.PrimitiveBuilder.build_pose

    base_height, upper_height = 2.15, 1.0

    # ===================
    # Create the builders
    # ===================

    base_builder = rod.builder.primitives.BoxBuilder(
        name="base", mass=1.0, x=0.15, y=0.15, z=base_height
    )

    upper_builder = rod.builder.primitives.BoxBuilder(
        name="upper", mass=0.5, x=0.15, y=0.15, z=upper_height
    )

    # =================
    # Create the joints
    # =================

    fixed = rod.Joint(
        name="fixed_joint",
        type="fixed",
        parent="world",
        child=base_builder.name,
    )

    pivot = rod.Joint(
        name="upper_joint",
        type="continuous",
        parent=base_builder.name,
        child=upper_builder.name,
        axis=rod.Axis(xyz=rod.Xyz([1, 0, 0])),
    )

    # ================
    # Create the links
    # ================

    base_link_builder = base_builder.build_link(
        name=base_builder.name,
        pose=build_pose(pos=np.array([0, 0, base_height / 2])),
    )
    base = base_link_builder.add_inertial().add_visual().add_collision().build()

    # Pose shared by the inertial, visual and collision elements of the upper link.
    upper_pose = build_pose(pos=np.array([0, 0, upper_height / 2]))

    upper_link_builder = upper_builder.build_link(
        name=upper_builder.name,
        pose=build_pose(relative_to=base.name, pos=np.array([0, 0, upper_height])),
    )
    upper = (
        upper_link_builder.add_inertial(pose=upper_pose)
        .add_visual(pose=upper_pose)
        .add_collision(pose=upper_pose)
        .build()
    )

    # ===============
    # Build the model
    # ===============

    pendulum = rod.Model(
        name="single_pendulum",
        link=[base, upper],
        joint=[fixed, pivot],
    )
    rod_model = rod.Sdf(version="1.10", model=pendulum)
    rod_model.model.resolve_frames()

    exporter = rod.urdf.exporter.UrdfExporter(pretty=True)
    urdf_string = exporter.to_urdf_string(sdf=rod_model.models()[0])

    return build_jaxsim_model(model_description=urdf_string)
@pytest.fixture(scope="session")
def jaxsim_model_garpez() -> js.model.JaxSimModel:
    """Fixture to create the original (unscaled) Garpez model."""

    exporter = rod.urdf.exporter.UrdfExporter(pretty=True)
    urdf_string = exporter.to_urdf_string(sdf=create_scalable_garpez_model())

    return build_jaxsim_model(model_description=urdf_string)
@pytest.fixture(scope="session")
def jaxsim_model_garpez_scaled(request) -> js.model.JaxSimModel:
    """
    Fixture to create the scaled version of the Garpez model.

    The scale of each link is read from `request.param` through the keys
    `link1_scale`, ..., `link4_scale`. Missing keys default to 1.0.
    """

    # Get the link scales from the request.
    scale_keys = ("link1_scale", "link2_scale", "link3_scale", "link4_scale")
    link_scales = {key: request.param.get(key, 1.0) for key in scale_keys}

    rod_model = create_scalable_garpez_model(**link_scales)

    exporter = rod.urdf.exporter.UrdfExporter(pretty=True)
    urdf_string = exporter.to_urdf_string(sdf=rod_model)

    return build_jaxsim_model(model_description=urdf_string)
def create_scalable_garpez_model(
    link1_scale: float = 1.0,
    link2_scale: float = 1.0,
    link3_scale: float = 1.0,
    link4_scale: float = 1.0,
) -> rod.Model:
    """
    Build a scalable rod model to test parameterization and scaling.

    The model is a serial chain of four links (box, sphere, cylinder, box)
    connected by three revolute joints about the y, z, and x axes respectively.

    Args:
        link1_scale: Scale factor for link 1 (applied to the x size of the box).
        link2_scale: Scale factor for link 2 (applied to the radius of the sphere).
        link3_scale: Scale factor for link 3 (applied to the length of the cylinder).
        link4_scale: Scale factor for link 4 (applied to the x size of the box).

    Returns:
        A rod model with the specified link scales.

    Note:
        The model is built assuming a constant link density, hence scaling the link will also have an impact on the link mass.
    """

    import numpy as np
    from rod.builder import primitives

    # ========================
    # Create the link builders
    # ========================

    density = 1000.0  # Fixed density in kg/m^3

    # Link 1: box, scaled along x.
    l1_x, l1_y, l1_z = 0.3 * link1_scale, 0.2, 0.2
    l1_volume = l1_x * l1_y * l1_z
    l1_mass = density * l1_volume
    link1_builder = primitives.BoxBuilder(
        name="link1", mass=l1_mass, x=l1_x, y=l1_y, z=l1_z
    )

    # Link 2: sphere, scaled radius.
    l2_radius = 0.1 * link2_scale
    l2_volume = 4 / 3 * np.pi * l2_radius**3
    l2_mass = density * l2_volume
    link2_builder = primitives.SphereBuilder(
        name="link2", mass=l2_mass, radius=l2_radius
    )

    # Link 3: cylinder, scaled length.
    l3_radius = 0.05
    l3_length = 0.5 * link3_scale
    l3_volume = np.pi * l3_radius**2 * l3_length
    l3_mass = density * l3_volume
    link3_builder = primitives.CylinderBuilder(
        name="link3", mass=l3_mass, radius=l3_radius, length=l3_length
    )

    # Link 4: box, scaled along x.
    l4_x, l4_y, l4_z = 0.3 * link4_scale, 0.2, 0.1
    l4_volume = l4_x * l4_y * l4_z
    l4_mass = density * l4_volume
    link4_builder = primitives.BoxBuilder(
        name="link4", mass=l4_mass, x=l4_x, y=l4_y, z=l4_z
    )

    # =================
    # Create the joints
    # =================

    # Each joint pose is expressed relative to its parent link and depends
    # on the (scaled) size of the parent.
    link1_to_link2 = rod.Joint(
        name="link1_to_link2",
        type="revolute",
        parent=link1_builder.name,
        child=link2_builder.name,
        pose=primitives.PrimitiveBuilder.build_pose(
            relative_to=link1_builder.name,
            pos=np.array([link1_builder.x, link1_builder.y / 2, link1_builder.z / 2]),
        ),
        axis=rod.Axis(xyz=rod.Xyz(xyz=[0, 1, 0]), limit=rod.Limit()),
    )

    link2_to_link3 = rod.Joint(
        name="link2_to_link3",
        type="revolute",
        parent=link2_builder.name,
        child=link3_builder.name,
        pose=primitives.PrimitiveBuilder.build_pose(
            relative_to=link2_builder.name,
            pos=np.array([link2_builder.radius, 0, -link2_builder.radius]),
        ),
        axis=rod.Axis(xyz=rod.Xyz(xyz=[0, 0, 1]), limit=rod.Limit()),
    )

    link3_to_link4 = rod.Joint(
        name="link3_to_link4",
        type="revolute",
        parent=link3_builder.name,
        child=link4_builder.name,
        pose=primitives.PrimitiveBuilder.build_pose(
            relative_to=link3_builder.name,
            pos=np.array([-link3_builder.radius, 0, -link3_builder.length]),
        ),
        axis=rod.Axis(xyz=rod.Xyz(xyz=[1, 0, 0]), limit=rod.Limit()),
    )

    # ================
    # Create the links
    # ================

    # For each link, the same pose is used for its inertial, visual, and
    # collision elements.
    link1_elements_pose = primitives.PrimitiveBuilder.build_pose(
        pos=np.array([link1_builder.x, link1_builder.y, link1_builder.z]) / 2
    )

    link1 = (
        link1_builder.build_link(
            name=link1_builder.name,
            pose=primitives.PrimitiveBuilder.build_pose(relative_to="__model__"),
        )
        .add_inertial(pose=link1_elements_pose)
        .add_visual(pose=link1_elements_pose)
        .add_collision(pose=link1_elements_pose)
        .build()
    )

    link2_elements_pose = primitives.PrimitiveBuilder.build_pose(
        pos=np.array([link2_builder.radius, 0, 0])
    )

    # The pose of links 2-4 is expressed relative to their parent joint.
    link2 = (
        link2_builder.build_link(
            name=link2_builder.name,
            pose=primitives.PrimitiveBuilder.build_pose(
                relative_to=link1_to_link2.name
            ),
        )
        .add_inertial(pose=link2_elements_pose)
        .add_visual(pose=link2_elements_pose)
        .add_collision(pose=link2_elements_pose)
        .build()
    )

    link3_elements_pose = primitives.PrimitiveBuilder.build_pose(
        pos=np.array([0, 0, -link3_builder.length / 2])
    )

    link3 = (
        link3_builder.build_link(
            name=link3_builder.name,
            pose=primitives.PrimitiveBuilder.build_pose(
                relative_to=link2_to_link3.name
            ),
        )
        .add_inertial(pose=link3_elements_pose)
        .add_visual(pose=link3_elements_pose)
        .add_collision(pose=link3_elements_pose)
        .build()
    )

    link4_elements_pose = primitives.PrimitiveBuilder.build_pose(
        # pos=np.array([0, 0, -link4_builder.z / 2])
        pos=np.array([link4_builder.x / 2, 0, -link4_builder.z / 2])
    )

    link4 = (
        link4_builder.build_link(
            name=link4_builder.name,
            pose=primitives.PrimitiveBuilder.build_pose(
                relative_to=link3_to_link4.name
            ),
        )
        .add_inertial(pose=link4_elements_pose)
        .add_visual(pose=link4_elements_pose)
        .add_collision(pose=link4_elements_pose)
        .build()
    )

    # ===========
    # Build model
    # ===========

    # Create model
    rod_model = rod.Model(
        name="model_demo",
        canonical_link=link1.name,
        link=[link1, link2, link3, link4],
        joint=[link1_to_link2, link2_to_link3, link3_to_link4],
    )

    # Convert the model to the URDF frame convention.
    rod_model.switch_frame_convention(
        frame_convention=rod.FrameConvention.Urdf,
        explicit_frames=True,
        attach_frames_to_links=True,
    )

    # Check that the resulting model serializes to a valid SDF.
    assert rod.Sdf(model=rod_model, version="1.10").serialize(validate=True)

    return rod_model
def create_model_with_missing_collision() -> rod.Model:
    """
    Build a rod model with a link that has a visual but no collision element.

    This model is used to test the export logic when collision elements are missing.

    Returns:
        A rod model with one link missing a collision element.
    """

    import numpy as np
    from rod.builder import primitives

    build_pose = primitives.PrimitiveBuilder.build_pose

    density = 1000.0  # Fixed density in kg/m^3

    # Builder of link1 (box), which gets both visual and collision.
    box_x, box_y, box_z = 0.3, 0.2, 0.2
    link1_builder = primitives.BoxBuilder(
        name="link1",
        mass=density * (box_x * box_y * box_z),
        x=box_x,
        y=box_y,
        z=box_z,
    )

    # Builder of link2 (sphere), which gets a visual but NO collision.
    sphere_radius = 0.1
    link2_builder = primitives.SphereBuilder(
        name="link2",
        mass=density * (4 / 3 * np.pi * sphere_radius**3),
        radius=sphere_radius,
    )

    # Revolute joint about the y axis connecting the two links.
    joint_pose = build_pose(
        relative_to=link1_builder.name,
        pos=np.array([link1_builder.x, link1_builder.y / 2, link1_builder.z / 2]),
    )
    link1_to_link2 = rod.Joint(
        name="link1_to_link2",
        type="revolute",
        parent=link1_builder.name,
        child=link2_builder.name,
        pose=joint_pose,
        axis=rod.Axis(xyz=rod.Xyz(xyz=[0, 1, 0]), limit=rod.Limit()),
    )

    # Build link1 with inertial, visual and collision elements.
    link1_elements_pose = build_pose(
        pos=np.array([link1_builder.x, link1_builder.y, link1_builder.z]) / 2
    )
    link1_in_progress = link1_builder.build_link(
        name=link1_builder.name,
        pose=build_pose(relative_to="__model__"),
    )
    link1 = (
        link1_in_progress.add_inertial(pose=link1_elements_pose)
        .add_visual(pose=link1_elements_pose)
        .add_collision(pose=link1_elements_pose)
        .build()
    )

    # Build link2 with inertial and visual elements only.
    link2_elements_pose = build_pose(pos=np.array([link2_builder.radius, 0, 0]))
    link2_in_progress = link2_builder.build_link(
        name=link2_builder.name,
        pose=build_pose(relative_to=link1_to_link2.name),
    )
    # Note: the collision element is intentionally not added here.
    link2 = (
        link2_in_progress.add_inertial(pose=link2_elements_pose)
        .add_visual(pose=link2_elements_pose)
        .build()
    )

    # Create model
    rod_model = rod.Model(
        name="model_missing_collision",
        canonical_link=link1.name,
        link=[link1, link2],
        joint=[link1_to_link2],
    )

    rod_model.switch_frame_convention(
        frame_convention=rod.FrameConvention.Urdf,
        explicit_frames=True,
        attach_frames_to_links=True,
    )

    # Check that the resulting model serializes to a valid SDF.
    assert rod.Sdf(model=rod_model, version="1.10").serialize(validate=True)

    return rod_model
@pytest.fixture(scope="session")
def jaxsim_model_missing_collision() -> js.model.JaxSimModel:
    """
    Fixture to create a model with a link that has a visual but no collision element.

    This is used to test the export logic when collision elements are missing.
    """

    exporter = rod.urdf.exporter.UrdfExporter(pretty=True)
    urdf_string = exporter.to_urdf_string(sdf=create_model_with_missing_collision())

    return build_jaxsim_model(model_description=urdf_string)
@pytest.fixture(scope="session")
def jaxsim_model_double_pendulum() -> js.model.JaxSimModel:
    """
    Fixture providing the JaxSim model of a double pendulum.

    The model is loaded from the `assets/double_pendulum.sdf` file located
    next to this module.

    Returns:
        The JaxSim model of a double pendulum.
    """

    assets_dir = pathlib.Path(__file__).parent / "assets"
    rod_model = load_model_from_file(assets_dir / "double_pendulum.sdf")

    return build_jaxsim_model(model_description=rod_model)
@pytest.fixture(scope="session")
def jaxsim_model_cartpole() -> js.model.JaxSimModel:
    """
    Fixture providing the JaxSim model of a cartpole.

    The model is loaded from the `examples/assets/cartpole.urdf` file of the
    repository.

    Returns:
        The JaxSim model of a cartpole.
    """

    repository_root = pathlib.Path(__file__).parent.parent
    urdf_path = repository_root / "examples" / "assets" / "cartpole.urdf"

    rod_model = load_model_from_file(urdf_path, is_urdf=True)

    return build_jaxsim_model(model_description=rod_model)
@pytest.fixture(scope="session")
def jaxsim_model_4_bar_linkage() -> js.model.JaxSimModel:
    """
    Fixture providing the JaxSim model of a 4-bar linkage (opened configuration).

    The model is loaded from the `assets/4_bar_opened.urdf` file located next
    to this module.

    Returns:
        The JaxSim model of the 4-bar linkage.
    """

    assets_dir = pathlib.Path(__file__).parent / "assets"
    rod_model = load_model_from_file(assets_dir / "4_bar_opened.urdf", is_urdf=True)

    return build_jaxsim_model(model_description=rod_model)
# ============================
# Collections of JaxSim models
# ============================
def get_jaxsim_model_fixture(
    model_name: str, request: pytest.FixtureRequest
) -> str | pathlib.Path:
    """
    Resolve a model name to the value of the corresponding model fixture.

    Args:
        model_name: The name of the model.
        request: The request object.

    Returns:
        The JaxSim model of the robot.

    Raises:
        ValueError: If the model name does not match any known fixture.
    """

    # Association between the supported model names and their fixtures.
    known_fixtures = (
        ("box", jaxsim_model_box),
        ("sphere", jaxsim_model_sphere),
        ("ergocub", jaxsim_model_ergocub),
        ("ergocub_reduced", jaxsim_model_ergocub_reduced),
        ("ur10", jaxsim_model_ur10),
        ("single_pendulum", jaxsim_model_single_pendulum),
        ("garpez", jaxsim_model_garpez),
        ("garpez_scaled", jaxsim_model_garpez_scaled),
    )

    for known_name, fixture_function in known_fixtures:
        if model_name == known_name:
            return request.getfixturevalue(fixture_function.__name__)

    raise ValueError(model_name)
@pytest.fixture(
    scope="session",
    params=["box", "sphere", "ur10", "ergocub", "ergocub_reduced"],
)
def jaxsim_models_all(request) -> pathlib.Path | str:
    """
    Fixture providing the JaxSim models of all supported robots.
    """

    return get_jaxsim_model_fixture(model_name=request.param, request=request)
@pytest.fixture(
    scope="session",
    params=["box", "ur10", "ergocub_reduced"],
)
def jaxsim_models_types(request) -> pathlib.Path | str:
    """
    Fixture providing JaxSim models of all types of supported robots.

    Note:
        At the moment, most of our tests use this fixture. It provides:
        - A robot with no joints.
        - A fixed-base robot.
        - A floating-base robot.
    """

    return get_jaxsim_model_fixture(model_name=request.param, request=request)
@pytest.fixture(
    scope="session",
    params=["box", "sphere"],
)
def jaxsim_models_no_joints(request) -> pathlib.Path | str:
    """
    Fixture providing JaxSim models of robots with no joints.
    """

    return get_jaxsim_model_fixture(model_name=request.param, request=request)
@pytest.fixture(
    scope="session",
    params=["ergocub", "ergocub_reduced"],
)
def jaxsim_models_floating_base(request) -> pathlib.Path | str:
    """
    Fixture providing JaxSim models of floating-base robots.
    """

    return get_jaxsim_model_fixture(model_name=request.param, request=request)
@pytest.fixture(
    scope="session",
    params=["ur10"],
)
def jaxsim_models_fixed_base(request) -> pathlib.Path | str:
    """
    Fixture providing JaxSim models of fixed-base robots.
    """

    return get_jaxsim_model_fixture(model_name=request.param, request=request)
@pytest.fixture(scope="function")
def set_jax_32bit(monkeypatch):
    """
    Fixture that temporarily sets JAX precision to 32-bit for the duration of the test.

    It removes the `jaxsim` and `js` names from the globals of this module and
    sets the `JAX_ENABLE_X64` environment variable to "0" through `monkeypatch`.

    Args:
        monkeypatch: The pytest fixture used to set the environment variable.
    """

    # Remove the module aliases only if they are still defined. This fixture
    # has function scope, so an unconditional `del` would raise a KeyError from
    # its second invocation onwards.
    # NOTE(review): presumably the aliases are dropped to force a fresh import
    # of jaxsim under the 32-bit configuration; confirm against the tests.
    for module_alias in ("jaxsim", "js"):
        globals().pop(module_alias, None)

    # Temporarily disable x64
    monkeypatch.setenv("JAX_ENABLE_X64", "0")
@pytest.fixture(scope="function")
def jaxsim_model_box_32bit(set_jax_32bit, request) -> js.model.JaxSimModel:
    """
    Fixture providing the JaxSim model of a box with 32-bit precision.

    Returns:
        The JaxSim model of a box with 32-bit precision.
    """

    box_model = get_jaxsim_model_fixture(model_name="box", request=request)

    return box_model
================================================
FILE: tests/test_actuation.py
================================================
import jax.numpy as jnp
from numpy.testing import assert_array_less
import jaxsim.api as js
import jaxsim.rbda
from jaxsim import VelRepr
from .utils import assert_allclose
def test_tn_curve(jaxsim_model_single_pendulum: js.model.JaxSimModel):
    """
    Check the resultant joint torques for two joint velocities.

    With `omega_th=1` and `omega_max=2`, a joint velocity of 1.5 must yield
    torques lower than the reference, and a joint velocity of 2.5 must yield
    zero torques.
    """

    model = jaxsim_model_single_pendulum

    # Configure the actuation parameters.
    act_params = jaxsim.rbda.actuation.ActuationParams()

    with act_params.editable(validate=False) as act_params:
        act_params.torque_max = 10
        act_params.omega_th = 1
        act_params.omega_max = 2

    with model.editable(validate=False) as model:
        model.actuation_params = act_params

    data = js.data.JaxSimModelData.build(
        model=model,
        velocity_representation=VelRepr.Inertial,
    )

    def resultant_torques_at(joint_speed: float):
        """Compute reference and resultant torques for a uniform joint speed."""

        joint_velocities = joint_speed * jnp.ones(model.dofs())
        τ_references = 30 * jnp.ones(model.dofs())

        data_at_speed = data.replace(model=model, joint_velocities=joint_velocities)

        τ_resultant = js.actuation_model.compute_resultant_torques(
            model, data_at_speed, joint_force_references=τ_references
        )

        return τ_references, τ_resultant

    # Velocity of 1.5: the torque is reduced.
    τ_references, τ_total = resultant_torques_at(1.5)
    assert_array_less(τ_total, τ_references)

    # Velocity of 2.5: the torque vanishes.
    _, τ_total = resultant_torques_at(2.5)
    assert_allclose(τ_total, 0.0)
================================================
FILE: tests/test_api_com.py
================================================
import jax
import jaxsim.api as js
from jaxsim import VelRepr
from . import utils
from .utils import assert_allclose
def test_com_properties(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Compare the CoM-related quantities computed by JaxSim with those provided
    by the object built with `utils.build_kindyncomputations_from_jaxsim_model`.
    """

    model = jaxsim_models_types

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=velocity_representation
    )

    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)

    # =====
    # Tests
    # =====

    # Pairs of (reference method name, JaxSim function) compared in all
    # velocity representations.
    always_compared = (
        ("com_position", js.com.com_position),
        ("centroidal_momentum_jacobian", js.com.centroidal_momentum_jacobian),
        ("centroidal_momentum", js.com.centroidal_momentum),
        (
            "locked_centroidal_spatial_inertia",
            js.com.locked_centroidal_spatial_inertia,
        ),
        (
            "average_centroidal_velocity_jacobian",
            js.com.average_centroidal_velocity_jacobian,
        ),
        ("average_centroidal_velocity", js.com.average_centroidal_velocity),
    )

    for reference_method, jaxsim_function in always_compared:
        expected = getattr(kin_dyn, reference_method)()
        computed = jaxsim_function(model=model, data=data)
        assert_allclose(expected, computed)

    if data.velocity_representation is not VelRepr.Body:

        # https://github.com/gbionics/jaxsim/pull/117#discussion_r1535486123
        vl_com_expected = kin_dyn.com_velocity()
        vl_com_computed = js.com.com_linear_velocity(model=model, data=data)
        assert_allclose(vl_com_expected, vl_com_computed)

        # iDynTree provides the bias acceleration in G[W] frame regardless of the velocity
        # representation. JaxSim, instead, returns the bias acceleration in G[B] when the
        # active representation is VelRepr.Body.
        v̇_bias_expected = kin_dyn.com_bias_acceleration()
        v̇_bias_computed = js.com.bias_acceleration(model=model, data=data)
        assert_allclose(v̇_bias_expected, v̇_bias_computed)
================================================
FILE: tests/test_api_contact.py
================================================
import jax
import jax.numpy as jnp
import rod
import jaxsim.api as js
from jaxsim import VelRepr
from .utils import assert_allclose
def test_contact_kinematics(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Check the pose and the linear velocity of the implicit contact frames.

    The rotation of each contact frame is compared with the rotation of the link
    its collidable point is attached to, the frame origin is compared with the
    collidable point position, and the collidable point velocity is compared with
    the linear part of the mixed contact Jacobian multiplied by ν.
    """

    model = jaxsim_models_types

    subkey = jax.random.split(prng_key, num=2)[1]
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=velocity_representation
    )

    # Select the parent link index of each enabled collidable point.
    contact_parameters = model.kin_dyn_parameters.contact_parameters
    enabled_point_idxs = contact_parameters.indices_of_enabled_collidable_points
    parent_link_idxs = jnp.array(contact_parameters.body, dtype=int)[
        enabled_point_idxs
    ]

    # =====
    # Tests
    # =====

    # Transforms of the implicit contact frames and of all the links.
    contact_transforms = js.contact.transforms(model=model, data=data)
    link_transforms = data._link_transforms

    # The rotation of each contact frame must match the rotation of its parent link.
    for contact_idx, parent_link_idx in enumerate(parent_link_idxs):
        contact_rotation = contact_transforms[contact_idx, 0:3, 0:3]
        parent_rotation = link_transforms[parent_link_idx][0:3, 0:3]
        assert_allclose(contact_rotation, parent_rotation)

    # The origin of each contact frame must coincide with its collidable point.
    point_positions = js.contact.collidable_point_positions(model=model, data=data)
    assert_allclose(point_positions, contact_transforms[:, 0:3, 3])

    # Velocity of the collidable points, expected to match the linear component
    # of the mixed 6D velocity of the corresponding implicit frames.
    point_velocities = js.contact.collidable_point_velocities(model=model, data=data)

    # Same velocity obtained through the mixed contact Jacobian.
    generalized_velocity = data.generalized_velocity
    mixed_jacobian = js.contact.jacobian(
        model=model, data=data, output_vel_repr=VelRepr.Mixed
    )
    jacobian_velocities = jnp.einsum(
        "c6g,g->c6", mixed_jacobian, generalized_velocity
    )

    assert_allclose(point_velocities, jacobian_velocities[:, 0:3])
def test_collidable_point_jacobians(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Compare the collidable point velocities with the ones obtained by multiplying
    the mixed contact Jacobian by the generalized velocity.
    """

    model = jaxsim_models_types

    subkey = jax.random.split(prng_key, num=2)[1]
    data = js.data.random_model_data(
        model=model,
        key=subkey,
        velocity_representation=velocity_representation,
    )

    # =====
    # Tests
    # =====

    # Linear velocity of the collidable points, i.e. the linear part of the mixed
    # velocity of the implicit frame C associated to each point.
    point_velocities = js.contact.collidable_point_velocities(model=model, data=data)

    # Generalized velocity and free-floating Jacobian of the frames C.
    generalized_velocity = data.generalized_velocity
    mixed_jacobian = js.contact.jacobian(
        model=model, data=data, output_vel_repr=VelRepr.Mixed
    )

    def jacobian_times_velocity(J, ν):
        return J @ ν

    # Map over the leading (contact point) axis of the Jacobian only.
    velocities_from_jacobian = jax.vmap(jacobian_times_velocity, in_axes=(0, None))(
        mixed_jacobian, generalized_velocity
    )

    assert_allclose(point_velocities, velocities_from_jacobian[:, 0:3])
def test_contact_jacobian_derivative(
    jaxsim_models_floating_base: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Compare the contact Jacobian derivative with the Jacobian derivative of
    frames placed on the enabled collidable points.

    The model is reloaded with ROD, one dummy frame is added on each enabled
    collidable point, and the JaxSim model and data are rebuilt from it. Both
    Jacobian derivatives are then evaluated on the rebuilt model and data.
    """

    model = jaxsim_models_floating_base

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model,
        key=subkey,
        velocity_representation=velocity_representation,
    )

    # Get the indices of the enabled collidable points.
    indices_of_enabled_collidable_points = (
        model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
    )

    # Extract the parent link names and the poses of the contact points.
    parent_link_names = js.link.idxs_to_names(
        model=model,
        link_indices=jnp.array(
            model.kin_dyn_parameters.contact_parameters.body, dtype=int
        )[indices_of_enabled_collidable_points],
    )

    L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
        indices_of_enabled_collidable_points
    ]

    # =====
    # Tests
    # =====

    # Load the model in ROD.
    rod_model = rod.Sdf.load(sdf=model.built_from).model

    # Add dummy frames on the contact points.
    for idx, link_name, L_p_C in zip(
        indices_of_enabled_collidable_points, parent_link_names, L_p_Ci, strict=True
    ):
        rod_model.add_frame(
            frame=rod.Frame(
                name=f"contact_point_{idx}",
                attached_to=link_name,
                pose=rod.Pose(
                    relative_to=link_name, pose=jnp.zeros(shape=(6,)).at[0:3].set(L_p_C)
                ),
            ),
        )

    # Rebuild the JaxSim model.
    model_with_frames = js.model.JaxSimModel.build_from_model_description(
        model_description=rod_model
    )
    model_with_frames = js.model.reduce(
        model=model_with_frames, considered_joints=model.joint_names()
    )

    # Rebuild the JaxSim data.
    data_with_frames = js.data.JaxSimModelData.build(
        model=model_with_frames,
        base_position=data.base_position,
        base_quaternion=data.base_orientation,
        joint_positions=data.joint_positions,
        base_linear_velocity=data.base_velocity[0:3],
        base_angular_velocity=data.base_velocity[3:6],
        joint_velocities=data.joint_velocities,
        velocity_representation=velocity_representation,
    )

    # Extract the indexes of the frames attached to the contact points.
    frame_idxs = js.frame.names_to_idxs(
        model=model_with_frames,
        frame_names=(
            f"contact_point_{idx}" for idx in indices_of_enabled_collidable_points
        ),
    )

    # Check that the number of frames is correct.
    assert len(frame_idxs) == len(parent_link_names)

    # Compute the contact Jacobian derivative.
    J̇_WC = js.contact.jacobian_derivative(
        model=model_with_frames, data=data_with_frames
    )

    # Compute the contact Jacobian derivative using frames.
    # The data object must be the one rebuilt for `model_with_frames`, so that
    # both sides of the comparison use the same model/data pair.
    J̇_WF = jax.vmap(
        js.frame.jacobian_derivative,
        in_axes=(None, None),
    )(model_with_frames, data_with_frames, frame_index=frame_idxs)

    # Compare the two Jacobians.
    assert_allclose(J̇_WC, J̇_WF)
================================================
FILE: tests/test_api_data.py
================================================
import jax
import jax.numpy as jnp
import pytest
from numpy.testing import assert_raises
import jaxsim.api as js
from jaxsim import VelRepr
from jaxsim.utils import Mutability
from . import utils
from .utils import assert_allclose
def test_data_valid(
    jaxsim_models_all: js.model.JaxSimModel,
):
    """
    Check that data built with default arguments is valid for its model.
    """

    default_data = js.data.JaxSimModelData.build(model=jaxsim_models_all)

    assert default_data.valid(model=jaxsim_models_all)
# Exercise `switch_velocity_representation` combined with `mutable_context`:
# an exception raised inside the contexts must leave `data` untouched, while a
# clean exit must keep the value assigned inside the contexts.
def test_data_switch_velocity_representation(
jaxsim_models_types: js.model.JaxSimModel,
prng_key: jax.Array,
):
model = jaxsim_models_types
# Only the second key returned by the split is used to generate the data.
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model, key=subkey, velocity_representation=VelRepr.Inertial
)
# =====
# Tests
# =====
new_base_linear_velocity = jnp.array([1.0, -2.0, 3.0])
# Keep the initial value to verify later that a failed update is not applied.
old_base_linear_velocity = data._base_linear_velocity
# The following should not change the original `data` object since it raises.
with pytest.raises(RuntimeError):
with data.switch_velocity_representation(
velocity_representation=VelRepr.Inertial
):
with data.mutable_context(mutability=Mutability.MUTABLE):
data._base_linear_velocity = new_base_linear_velocity
raise RuntimeError("This is raised on purpose inside this context")
# The assignment attempted before the exception must not be visible.
assert_allclose(data._base_linear_velocity, old_base_linear_velocity)
# The following instead should result in an updated `data` object.
with (
data.switch_velocity_representation(velocity_representation=VelRepr.Inertial),
data.mutable_context(mutability=Mutability.MUTABLE),
):
data._base_linear_velocity = new_base_linear_velocity
assert_allclose(data._base_linear_velocity, new_base_linear_velocity)
def test_data_change_velocity_representation(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """
    Check the base velocity exposed by the data object under each velocity
    representation against reference `kin_dyn` objects built in the same
    representation.
    """

    model = jaxsim_models_types

    subkey = jax.random.split(prng_key, num=2)[1]
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=VelRepr.Inertial
    )

    # =====
    # Tests
    # =====

    def build_reference():
        return utils.build_kindyncomputations_from_jaxsim_model(
            model=model, data=data
        )

    # Build one reference object per velocity representation.
    kin_dyn_inertial = build_reference()

    with data.switch_velocity_representation(VelRepr.Mixed):
        kin_dyn_mixed = build_reference()

    with data.switch_velocity_representation(VelRepr.Body):
        kin_dyn_body = build_reference()

    # The data object was created with the inertial-fixed representation.
    assert_allclose(data.base_velocity, kin_dyn_inertial.base_velocity())

    # The remaining checks are performed only on floating-base models.
    if not model.floating_base():
        return

    with data.switch_velocity_representation(VelRepr.Mixed):
        assert_allclose(data.base_velocity, kin_dyn_mixed.base_velocity())

        # The linear part is expected to differ from the stored one...
        with assert_raises(AssertionError):
            assert_allclose(data.base_velocity[0:3], data._base_linear_velocity)

        # ...while the angular part is expected to be the same.
        assert_allclose(data.base_velocity[3:6], data._base_angular_velocity)

    with data.switch_velocity_representation(VelRepr.Body):
        assert_allclose(data.base_velocity, kin_dyn_body.base_velocity())

        # Both the linear and the angular parts are expected to differ.
        with assert_raises(AssertionError):
            assert_allclose(data.base_velocity[0:3], data._base_linear_velocity)

        with assert_raises(AssertionError):
            assert_allclose(data.base_velocity[3:6], data._base_angular_velocity)
================================================
FILE: tests/test_api_frame.py
================================================
import jax
import jax.numpy as jnp
import pytest
from jax.errors import JaxRuntimeError
from numpy.testing import assert_array_equal
import jaxsim.api as js
from jaxsim import VelRepr
from jaxsim.math.quaternion import Quaternion
from . import utils
from .utils import assert_allclose
def test_frame_index(jaxsim_models_types: js.model.JaxSimModel):
    """
    Check the conversions between frame names and frame indices, including the
    errors raised for unknown names and out-of-range indices.
    """

    model = jaxsim_models_types

    # =====
    # Tests
    # =====

    n_l = model.number_of_links()
    n_f = len(model.frame_names())

    # Frame indices are expected to start right after the link indices.
    for frame_index, frame_name in enumerate(model.frame_names(), start=n_l):
        assert js.frame.name_to_idx(model=model, frame_name=frame_name) == frame_index
        assert js.frame.idx_to_name(model=model, frame_index=frame_index) == frame_name

        parent_link_index = js.frame.idx_of_parent_link(
            model=model, frame_index=frame_index
        )
        assert parent_link_index < model.number_of_links()

    # See discussion in https://github.com/gbionics/jaxsim/pull/280
    all_frame_idxs = js.frame.names_to_idxs(
        model=model, frame_names=model.frame_names()
    )
    assert_array_equal(all_frame_idxs, jnp.arange(n_l, n_l + n_f))

    # Round trip: names -> indices -> names.
    round_trip_idxs = js.frame.names_to_idxs(
        model=model, frame_names=model.frame_names()
    )
    round_trip_names = js.frame.idxs_to_names(
        model=model, frame_indices=tuple(round_trip_idxs.tolist())
    )
    assert round_trip_names == model.frame_names()

    with pytest.raises(ValueError):
        _ = js.frame.name_to_idx(model=model, frame_name="non_existent_frame")

    # Negative index, last link index, and first index past the last frame.
    invalid_frame_indices = (-1, n_l - 1, n_l + n_f)

    for invalid_index in invalid_frame_indices:
        with pytest.raises(JaxRuntimeError):
            _ = js.frame.idx_to_name(model=model, frame_index=invalid_index)

    for invalid_index in invalid_frame_indices:
        with pytest.raises(JaxRuntimeError):
            _ = js.frame.idx_of_parent_link(model=model, frame_index=invalid_index)
def test_frame_transforms(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """
    Compare the world-to-frame transforms computed by JaxSim with the ones
    returned by the reference `kin_dyn` object, frame by frame.

    Frames whose name contains "skin", "laser" or "depth" are skipped to keep
    the test fast on models with many frames.
    """

    model = jaxsim_models_types

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=VelRepr.Inertial
    )

    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)

    # Get all names of frames in the iDynTree model.
    frame_names = [
        frame.name
        for frame in model.description.frames
        if frame.name in kin_dyn.frame_names()
    ]

    # Skip some entry of models with many frames.
    # A frame is kept only if its name contains none of the excluded keywords.
    # The conditions must be joined with `and`: with `or` the filter would keep
    # every name that does not contain all three keywords at once.
    frame_names = [
        name
        for name in frame_names
        if "skin" not in name and "laser" not in name and "depth" not in name
    ]

    # Get indices of frames.
    frame_indices = tuple(
        frame.index
        for frame in model.description.frames
        if frame.index is not None and frame.name in frame_names
    )

    # =====
    # Tests
    # =====

    # Every selected frame must have an index in the model description.
    assert len(frame_indices) == len(frame_names)

    for frame_name in frame_names:

        W_H_F_js = js.frame.transform(
            model=model,
            data=data,
            frame_index=js.frame.name_to_idx(model=model, frame_name=frame_name),
        )
        W_H_F_idt = kin_dyn.frame_transform(frame_name=frame_name)
        assert_allclose(
            W_H_F_js, W_H_F_idt, atol=1e-6, err_msg=f"Mismatch in {frame_name}"
        )
def test_frame_jacobians(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Compare frame Jacobians and frame velocities computed by JaxSim with the
    ones returned by the reference `kin_dyn` object.
    """

    model = jaxsim_models_types

    subkey = jax.random.split(prng_key, num=2)[1]
    data = js.data.random_model_data(
        model=model,
        key=subkey,
        velocity_representation=velocity_representation,
    )

    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)

    # Names of the frames that are also known to the reference object.
    selected_names = [
        frame.name
        for frame in model.description.frames
        if frame.name in kin_dyn.frame_names()
    ]

    # Lower the number of frames for models with many frames.
    if model.name().lower() == "ergocub":
        assert any("sole" in name for name in selected_names)
        selected_names = [name for name in selected_names if "sole" in name]

    # Indices of the selected frames, in the order of the model description.
    selected_indices = tuple(
        frame.index
        for frame in model.description.frames
        if frame.index is not None and frame.name in selected_names
    )

    # =====
    # Tests
    # =====

    assert len(selected_indices) == len(selected_names)

    # Jacobians.
    for name, index in zip(selected_names, selected_indices, strict=True):
        jacobian_js = js.frame.jacobian(model=model, data=data, frame_index=index)
        jacobian_idt = kin_dyn.jacobian_frame(frame_name=name)
        assert_allclose(jacobian_js, jacobian_idt, err_msg=f"Mismatch in {name}")

    # Velocities.
    for name, index in zip(selected_names, selected_indices, strict=True):
        velocity_idt = kin_dyn.frame_velocity(frame_name=name)
        velocity_js = js.frame.velocity(model=model, data=data, frame_index=index)
        assert_allclose(velocity_js, velocity_idt, err_msg=f"Mismatch in {name}")
def test_frame_jacobian_derivative(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Validate the frame Jacobian derivative in two ways:

    - against the time derivative of the frame Jacobian obtained with AD;
    - against the frame bias acceleration J̇ν returned by the reference
      `kin_dyn` object.

    Frames whose name contains "skin", "laser" or "depth" are skipped to keep
    the test fast on models with many frames.
    """

    model = jaxsim_models_types

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=velocity_representation
    )

    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)

    # Get all names of frames in the iDynTree model.
    frame_names = [
        frame.name
        for frame in model.description.frames
        if frame.name in kin_dyn.frame_names()
    ]

    # Skip some entry of models with many frames.
    # A frame is kept only if its name contains none of the excluded keywords.
    # The conditions must be joined with `and`: with `or` the filter would keep
    # every name that does not contain all three keywords at once.
    frame_names = [
        name
        for name in frame_names
        if "skin" not in name and "laser" not in name and "depth" not in name
    ]

    frame_idxs = js.frame.names_to_idxs(model=model, frame_names=tuple(frame_names))

    # ===============
    # Test against AD
    # ===============

    # Get the generalized velocity.
    I_ν = data.generalized_velocity

    # Compute J̇.
    O_J̇_WF_I = jax.vmap(
        lambda frame_index: js.frame.jacobian_derivative(
            model=model, data=data, frame_index=frame_index
        )
    )(frame_idxs)

    assert O_J̇_WF_I.shape == (len(frame_names), 6, 6 + model.dofs())

    # Compute the plain Jacobian.
    # This function will be used to compute the Jacobian derivative with AD.
    def J(q, frame_idxs) -> jax.Array:

        # q stacks the base position, the base quaternion and the joint positions.
        data_ad = js.data.JaxSimModelData.build(
            model=model,
            velocity_representation=data.velocity_representation,
            base_position=q[:3],
            base_quaternion=q[3:7],
            joint_positions=q[7:],
        )

        O_J_ad_WF_I = jax.vmap(
            lambda model, data, frame_index: js.frame.jacobian(
                model=model, data=data, frame_index=frame_index
            ),
            in_axes=(None, None, 0),
        )(model, data_ad, frame_idxs)

        return O_J_ad_WF_I

    def compute_q(data: js.data.JaxSimModelData) -> jax.Array:

        q = jnp.hstack(
            [
                data.base_position,
                data.base_orientation,
                data.joint_positions,
            ]
        )

        return q

    def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:

        # The quaternion derivative below takes the body-fixed angular velocity.
        with data.switch_velocity_representation(VelRepr.Body):
            B_ω_WB = data.base_velocity[3:6]

        # The linear part of the mixed base velocity is used as derivative of
        # the base position.
        with data.switch_velocity_representation(VelRepr.Mixed):
            W_ṗ_B = data.base_velocity[0:3]

        W_Q̇_B = Quaternion.derivative(
            quaternion=data.base_orientation,
            omega=B_ω_WB,
            omega_in_body_fixed=True,
            K=0.0,
        ).squeeze()

        q̇ = jnp.hstack([W_ṗ_B, W_Q̇_B, data.joint_velocities])

        return q̇

    # Compute q and q̇.
    q = compute_q(data)
    q̇ = compute_q̇(data)

    # Compute dJ/dt with AD.
    dJ_dq = jax.jacfwd(J, argnums=0)(q, frame_idxs)
    O_J̇_ad_WF_I = jnp.einsum("ijkq,q->ijk", dJ_dq, q̇)

    assert_allclose(O_J̇_WF_I, O_J̇_ad_WF_I)

    # =====================
    # Test against iDynTree
    # =====================

    # Compute the product J̇ν.
    O_a_bias_WF = jax.vmap(
        lambda O_J̇_WF_I, I_ν: O_J̇_WF_I @ I_ν,
        in_axes=(0, None),
    )(O_J̇_WF_I, I_ν)

    # Compare the two computations.
    for index, name in enumerate(frame_names):
        J̇ν_idt = kin_dyn.frame_bias_acc(frame_name=name)
        J̇ν_js = O_a_bias_WF[index]
        assert_allclose(J̇ν_js, J̇ν_idt, err_msg=f"Mismatch in {name}")
================================================
FILE: tests/test_api_joint.py
================================================
import jax.numpy as jnp
import pytest
from jax.errors import JaxRuntimeError
from numpy.testing import assert_array_equal
import jaxsim.api as js
def test_joint_index(
    jaxsim_models_types: js.model.JaxSimModel,
):
    """
    Check the conversions between joint names and joint indices, including the
    errors raised for unknown names and out-of-range indices.
    """

    model = jaxsim_models_types

    # =====
    # Tests
    # =====

    # Joint indices are expected to follow the order of the joint names.
    for expected_index, name in enumerate(model.joint_names()):
        assert js.joint.name_to_idx(model=model, joint_name=name) == expected_index
        assert js.joint.idx_to_name(model=model, joint_index=expected_index) == name

    # See discussion in https://github.com/gbionics/jaxsim/pull/280
    all_joint_idxs = js.joint.names_to_idxs(
        model=model, joint_names=model.joint_names()
    )
    assert_array_equal(all_joint_idxs, jnp.arange(model.number_of_joints()))

    # Round trip: names -> indices -> names.
    round_trip_idxs = js.joint.names_to_idxs(
        model=model, joint_names=model.joint_names()
    )
    round_trip_names = js.joint.idxs_to_names(
        model=model, joint_indices=tuple(round_trip_idxs.tolist())
    )
    assert round_trip_names == model.joint_names()

    with pytest.raises(ValueError):
        _ = js.joint.name_to_idx(model=model, joint_name="non_existent_joint")

    with pytest.raises(JaxRuntimeError):
        _ = js.joint.idx_to_name(model=model, joint_index=-1)

    with pytest.raises(IndexError):
        _ = js.joint.idx_to_name(model=model, joint_index=model.number_of_joints())
================================================
FILE: tests/test_api_link.py
================================================
import jax
import jax.numpy as jnp
import pytest
from jax.errors import JaxRuntimeError
from numpy.testing import assert_array_equal
import jaxsim.api as js
import jaxsim.math
from jaxsim import VelRepr
from . import utils
from .utils import assert_allclose
def test_link_index(
    jaxsim_models_types: js.model.JaxSimModel,
):
    """
    Check the conversions between link names and link indices, including the
    errors raised for unknown names and out-of-range indices.
    """

    model = jaxsim_models_types

    # =====
    # Tests
    # =====

    # Link indices are expected to follow the order of the link names.
    for expected_index, name in enumerate(model.link_names()):
        assert js.link.name_to_idx(model=model, link_name=name) == expected_index
        assert js.link.idx_to_name(model=model, link_index=expected_index) == name

    # See discussion in https://github.com/gbionics/jaxsim/pull/280
    all_link_idxs = js.link.names_to_idxs(model=model, link_names=model.link_names())
    assert_array_equal(all_link_idxs, jnp.arange(model.number_of_links()))

    # Round trip: names -> indices -> names.
    round_trip_idxs = js.link.names_to_idxs(
        model=model, link_names=model.link_names()
    )
    round_trip_names = js.link.idxs_to_names(
        model=model, link_indices=tuple(round_trip_idxs.tolist())
    )
    assert round_trip_names == model.link_names()

    with pytest.raises(ValueError):
        _ = js.link.name_to_idx(model=model, link_name="non_existent_link")

    with pytest.raises(JaxRuntimeError):
        _ = js.link.idx_to_name(model=model, link_index=-1)

    with pytest.raises(IndexError):
        _ = js.link.idx_to_name(model=model, link_index=model.number_of_links())
def test_link_inertial_properties(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """
    Compare the mass and the spatial inertia of every link except the base link
    with the values returned by the reference `kin_dyn` object.
    """

    model = jaxsim_models_types

    subkey = jax.random.split(prng_key, num=2)[1]
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=VelRepr.Inertial
    )

    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)

    # =====
    # Tests
    # =====

    link_idxs = jnp.arange(model.number_of_links())

    for link_name, link_idx in zip(model.link_names(), link_idxs, strict=True):

        # The base link is excluded from the comparison.
        if link_name != model.base_link():

            mismatch_message = f"Mismatch in {link_name}"

            assert_allclose(
                js.link.mass(model=model, link_index=link_idx),
                kin_dyn.link_mass(link_name=link_name),
                err_msg=mismatch_message,
            )

            assert_allclose(
                js.link.spatial_inertia(model=model, link_index=link_idx),
                kin_dyn.link_spatial_inertia(link_name=link_name),
                err_msg=mismatch_message,
            )
def test_link_transforms(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """
    Check that the per-link transforms agree with the transforms stored in the
    data object and with the ones returned by the reference `kin_dyn` object.
    """

    model = jaxsim_models_types

    subkey = jax.random.split(prng_key, num=2)[1]
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=VelRepr.Inertial
    )

    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)

    # =====
    # Tests
    # =====

    # Transforms of all links as stored in the data object.
    stored_transforms = data._link_transforms

    def single_link_transform(idx):
        return js.link.transform(model=model, data=data, link_index=idx)

    # Transforms of all links computed one link at a time.
    per_link_transforms = jax.vmap(single_link_transform)(
        jnp.arange(model.number_of_links())
    )

    assert_allclose(stored_transforms, per_link_transforms)

    for transform, link_name in zip(
        per_link_transforms, model.link_names(), strict=True
    ):
        reference_transform = kin_dyn.frame_transform(frame_name=link_name)
        assert_allclose(
            transform, reference_transform, err_msg=f"Mismatch in {link_name}"
        )
# Compare link Jacobians and link velocities computed by JaxSim with the ones
# returned by the reference `kin_dyn` object, and check the conversion of the
# link velocities to the other velocity representations.
def test_link_jacobians(
jaxsim_models_types: js.model.JaxSimModel,
velocity_representation: VelRepr,
prng_key: jax.Array,
):
model = jaxsim_models_types
# Only the second key returned by the split is used to generate the data.
_, subkey = jax.random.split(prng_key, num=2)
data = js.data.random_model_data(
model=model,
key=subkey,
velocity_representation=velocity_representation,
)
kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)
# =====
# Tests
# =====
# Compute the Jacobian of each link, mapping over all link indices.
J_WL_links = jax.vmap(
lambda idx: js.link.jacobian(model=model, data=data, link_index=idx)
)(jnp.arange(model.number_of_links()))
for J_WL, link_name in zip(J_WL_links, model.link_names(), strict=True):
assert_allclose(
J_WL,
kin_dyn.jacobian_frame(frame_name=link_name),
err_msg=f"Mismatch in {link_name}",
)
# The following is true only in inertial-fixed representation.
# NOTE(review): the check below has no visible guard on the active
# representation although the test is parametrized over it — confirm whether
# the comment above is stale or a guard is missing.
J_WL_model = js.model.generalized_free_floating_jacobian(model=model, data=data)
assert_allclose(J_WL_model, J_WL_links)
# Compare the link velocities in the active velocity representation.
for link_name, link_idx in zip(
model.link_names(),
jnp.arange(model.number_of_links()),
strict=True,
):
v_WL_idt = kin_dyn.frame_velocity(frame_name=link_name)
v_WL_js = js.link.velocity(model=model, data=data, link_index=link_idx)
assert_allclose(v_WL_js, v_WL_idt, err_msg=f"Mismatch in {link_name}")
# Test conversion to a different output velocity representation.
for other_repr in {VelRepr.Inertial, VelRepr.Body, VelRepr.Mixed}.difference(
{data.velocity_representation}
):
# Build the reference object with the data switched to the other representation.
with data.switch_velocity_representation(other_repr):
kin_dyn_other_repr = utils.build_kindyncomputations_from_jaxsim_model(
model=model, data=data
)
for link_name, link_idx in zip(
model.link_names(),
jnp.arange(model.number_of_links()),
strict=True,
):
v_WL_idt = kin_dyn_other_repr.frame_velocity(frame_name=link_name)
# Request the output representation explicitly through `output_vel_repr`.
v_WL_js = js.link.velocity(
model=model, data=data, link_index=link_idx, output_vel_repr=other_repr
)
assert_allclose(v_WL_js, v_WL_idt, err_msg=f"Mismatch in {link_name}")
def test_link_bias_acceleration(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Compare the link bias accelerations with the ones returned by the reference
    `kin_dyn` object, and check their conversion between the inertial-fixed and
    the body-fixed representations.
    """

    model = jaxsim_models_types

    subkey = jax.random.split(prng_key, num=2)[1]
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=velocity_representation
    )

    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)

    # =====
    # Tests
    # =====

    link_idxs = jnp.arange(model.number_of_links())

    for name, index in zip(model.link_names(), link_idxs, strict=True):
        bias_acc_idt = kin_dyn.frame_bias_acc(frame_name=name)
        bias_acc_js = js.link.bias_acceleration(
            model=model, data=data, link_index=index
        )
        assert_allclose(bias_acc_js, bias_acc_idt, err_msg=f"Mismatch in {name}")

    # Test that the conversion of the link bias acceleration works as expected.
    # The mixed representation is excluded because converting the acceleration is
    # more complex than using the plain 6D transform matrix.
    active_repr = data.velocity_representation

    def to_adjoint(transform):
        return jaxsim.math.Adjoint.from_transform(transform=transform)

    def to_inverse_adjoint(transform):
        return jaxsim.math.Adjoint.from_transform(transform=transform, inverse=True)

    def apply_adjoint(adjoint, acceleration):
        return adjoint @ acceleration

    if active_repr == VelRepr.Inertial:

        # Inertial-fixed to body-fixed conversion.
        W_H_L = data._link_transforms
        W_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)

        with data.switch_velocity_representation(VelRepr.Body):
            W_X_L = jax.vmap(to_adjoint)(W_H_L)
            L_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)
            W_a_bias_WL_converted = jax.vmap(apply_adjoint)(W_X_L, L_a_bias_WL)

        assert_allclose(W_a_bias_WL, W_a_bias_WL_converted)

    elif active_repr == VelRepr.Body:

        # Body-fixed to inertial-fixed conversion.
        W_H_L = data._link_transforms
        L_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)

        with data.switch_velocity_representation(VelRepr.Inertial):
            L_X_W = jax.vmap(to_inverse_adjoint)(W_H_L)
            W_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)
            L_a_bias_WL_converted = jax.vmap(apply_adjoint)(L_X_W, W_a_bias_WL)

        assert_allclose(L_a_bias_WL, L_a_bias_WL_converted)
def test_link_jacobian_derivative(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Validate the link Jacobian derivative against the link bias accelerations
    and against the time derivative of the Jacobian obtained with AD.
    """

    model = jaxsim_models_types

    subkey = jax.random.split(prng_key, num=2)[1]
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=velocity_representation
    )

    # =====
    # Tests
    # =====

    # Generalized velocity.
    generalized_velocity = data.generalized_velocity

    def single_link_jacobian_derivative(link_index):
        return js.link.jacobian_derivative(
            model=model, data=data, link_index=link_index
        )

    # Jacobian derivative of all links.
    jacobian_derivative = jax.vmap(single_link_jacobian_derivative)(
        jnp.arange(model.number_of_links())
    )

    # The product J̇ν must match the link bias accelerations.
    bias_accelerations = js.model.link_bias_accelerations(model=model, data=data)
    assert_allclose(
        jnp.einsum("l6g,g->l6", jacobian_derivative, generalized_velocity),
        bias_accelerations,
    )

    # Plain Jacobian as a function of q, differentiated below with AD.
    # The active velocity representation is forwarded to the data built from q.
    def jacobian_of_q(q) -> jax.Array:
        return js.model.generalized_free_floating_jacobian(
            model=model,
            data=js.data.JaxSimModelData.build(
                model=model,
                velocity_representation=data.velocity_representation,
                base_position=q[:3],
                base_quaternion=q[3:7],
                joint_positions=q[7:],
            ),
        )

    def stack_positions(data: js.data.JaxSimModelData) -> jax.Array:
        return jnp.hstack(
            [data.base_position, data.base_orientation, data.joint_positions]
        )

    def stack_position_derivatives(data: js.data.JaxSimModelData) -> jax.Array:

        # Body-fixed angular velocity, consumed by the quaternion derivative.
        with data.switch_velocity_representation(VelRepr.Body):
            body_angular_velocity = data.base_velocity[3:6]

        # Linear part of the mixed base velocity.
        with data.switch_velocity_representation(VelRepr.Mixed):
            base_linear_velocity = data.base_velocity[0:3]

        quaternion_derivative = jaxsim.math.Quaternion.derivative(
            quaternion=data.base_orientation,
            omega=body_angular_velocity,
            omega_in_body_fixed=True,
            K=0.0,
        ).squeeze()

        return jnp.hstack(
            [base_linear_velocity, quaternion_derivative, data.joint_velocities]
        )

    # Compute q and its time derivative.
    q = stack_positions(data)
    q_dot = stack_position_derivatives(data)

    # Compute dJ/dt with AD.
    jacobian_wrt_q = jax.jacfwd(jacobian_of_q, argnums=0)(q)
    jacobian_derivative_ad = jnp.einsum("ijkq,q->ijk", jacobian_wrt_q, q_dot)

    assert_allclose(jacobian_derivative, jacobian_derivative_ad)
    assert_allclose(
        jnp.einsum("l6g,g->l6", jacobian_derivative_ad, generalized_velocity),
        jnp.einsum("l6g,g->l6", jacobian_derivative, generalized_velocity),
    )
================================================
FILE: tests/test_api_model.py
================================================
import pathlib
import jax
import jax.numpy as jnp
import numpy as np
import rod
import jaxsim.api as js
import jaxsim.math
from jaxsim import VelRepr
from . import utils
from .utils import assert_allclose
# Build the full ErgoCub model, reduce it to a subset of its joints while locking
# the removed joints at the positions of the full-model data, and check that the
# reduced model preserves mass, CoM position, link/frame transforms and
# collidable point positions, both within JaxSim and against iDynTree.
def test_model_creation_and_reduction(
jaxsim_model_ergocub: js.model.JaxSimModel,
prng_key: jax.Array,
):
model_full = jaxsim_model_ergocub
# Only the second key returned by the split is used to generate the data.
_, subkey = jax.random.split(prng_key, num=2)
# Lower and upper bounds of the base position coincide, so the base position
# is presumably fixed at (0, 0, 0.8) — TODO confirm against random_model_data.
data_full = js.data.random_model_data(
model=model_full,
key=subkey,
velocity_representation=VelRepr.Inertial,
base_pos_bounds=((0, 0, 0.8), (0, 0, 0.8)),
)
# =====
# Tests
# =====
# Check that the data of the full model is valid.
assert data_full.valid(model=model_full)
# Build the ROD model from the original description.
assert isinstance(model_full.built_from, str | pathlib.Path)
rod_sdf = rod.Sdf.load(sdf=model_full.built_from)
assert len(rod_sdf.models()) == 1
# Get all non-fixed joint names from the description.
joint_names_in_description = [
j.name for j in rod_sdf.models()[0].joints() if j.type != "fixed"
]
# Check that all non-fixed joints are in the full model.
assert set(joint_names_in_description) == set(model_full.joint_names())
# ================
# Reduce the model
# ================
# Get the names of the joints to keep in the reduced model.
# A joint is kept only if its name contains none of the substrings listed below.
reduced_joints = tuple(
j
for j in model_full.joint_names()
if "camera" not in j
and "neck" not in j
and "wrist" not in j
and "thumb" not in j
and "index" not in j
and "middle" not in j
and "ring" not in j
and "pinkie" not in j
#
and "elbow" not in j
and "shoulder" not in j
and "torso" not in j
and "r_knee" not in j
)
# Reduce the model.
# Note: here we also specify a non-zero position of the removed joints.
# The process should take into account the corresponding joint transforms
# when the link-joint-link chains are lumped together.
model_reduced = js.model.reduce(
model=model_full,
considered_joints=reduced_joints,
locked_joint_positions=dict(
zip(
model_full.joint_names(),
data_full.joint_positions.tolist(),
strict=True,
)
),
)
# Check DoFs.
assert model_full.dofs() != model_reduced.dofs()
# Check that all non-fixed joints are in the reduced model.
assert set(reduced_joints) == set(model_reduced.joint_names())
# Check that the reduced model maintains the same terrain of the full model.
assert model_full.terrain == model_reduced.terrain
# Check that the reduced model maintains the same contact model of the full model.
assert model_full.contact_model == model_reduced.contact_model
# Check that the reduced model maintains the same integration step of the full model.
assert model_full.time_step == model_reduced.time_step
# Indices, in the full model, of the joints kept in the reduced model.
joint_idxs = js.joint.names_to_idxs(
model=model_full, joint_names=model_reduced.joint_names()
)
# Build the data of the reduced model.
data_reduced = js.data.JaxSimModelData.build(
model=model_reduced,
base_position=data_full.base_position,
base_quaternion=data_full.base_orientation,
joint_positions=data_full.joint_positions[joint_idxs],
base_linear_velocity=data_full.base_velocity[0:3],
base_angular_velocity=data_full.base_velocity[3:6],
joint_velocities=data_full.joint_velocities[joint_idxs],
velocity_representation=data_full.velocity_representation,
)
# Check that the reduced model data is valid.
# It must be valid for the reduced model only, not for the full one.
assert not data_reduced.valid(model=model_full)
assert data_reduced.valid(model=model_reduced)
# Check that the total mass is preserved.
assert_allclose(
js.model.total_mass(model=model_full), js.model.total_mass(model=model_reduced)
)
# Check that the CoM position is preserved.
assert_allclose(
js.com.com_position(model=model_full, data=data_full),
js.com.com_position(model=model_reduced, data=data_reduced),
atol=1e-6,
)
# Check that joint serialization works.
assert_allclose(data_full.joint_positions[joint_idxs], data_reduced.joint_positions)
assert_allclose(
data_full.joint_velocities[joint_idxs], data_reduced.joint_velocities
)
# Check that link transforms are preserved.
# Links are looked up by name in each model separately.
for link_name in model_reduced.link_names():
W_H_L_full = js.link.transform(
model=model_full,
data=data_full,
link_index=js.link.name_to_idx(model=model_full, link_name=link_name),
)
W_H_L_reduced = js.link.transform(
model=model_reduced,
data=data_reduced,
link_index=js.link.name_to_idx(model=model_reduced, link_name=link_name),
)
assert_allclose(W_H_L_full, W_H_L_reduced)
# Check that collidable point positions are preserved.
assert_allclose(
js.contact.collidable_point_positions(model=model_full, data=data_full),
js.contact.collidable_point_positions(model=model_reduced, data=data_reduced),
)
# =====================
# Test against iDynTree
# =====================
kin_dyn_full = utils.build_kindyncomputations_from_jaxsim_model(
model=model_full, data=data_full
)
kin_dyn_reduced = utils.build_kindyncomputations_from_jaxsim_model(
model=model_reduced, data=data_reduced
)
# Check that the total mass is preserved.
assert_allclose(kin_dyn_full.total_mass(), kin_dyn_reduced.total_mass())
# Check that the CoM positions match.
assert_allclose(kin_dyn_full.com_position(), kin_dyn_reduced.com_position())
assert_allclose(
kin_dyn_full.com_position(),
js.com.com_position(model=model_reduced, data=data_reduced),
)
# Check that link transforms match.
for link_name in model_reduced.link_names():
assert_allclose(
kin_dyn_reduced.frame_transform(frame_name=link_name),
kin_dyn_full.frame_transform(frame_name=link_name),
err_msg=f"Mismatch in link {link_name}",
)
assert_allclose(
kin_dyn_reduced.frame_transform(frame_name=link_name),
js.link.transform(
model=model_reduced,
data=data_reduced,
link_index=js.link.name_to_idx(
model=model_reduced, link_name=link_name
),
),
err_msg=f"Mismatch in link {link_name}",
)
# Check that frame transforms match.
for frame_name in model_reduced.frame_names():
# Only frames that are also known to the reduced iDynTree object are compared.
if frame_name not in kin_dyn_reduced.frame_names():
continue
# Skip some entries of models with many frames.
if "skin" in frame_name or "laser" in frame_name or "depth" in frame_name:
continue
assert_allclose(
kin_dyn_reduced.frame_transform(frame_name=frame_name),
kin_dyn_full.frame_transform(frame_name=frame_name),
err_msg=f"Mismatch in frame {frame_name}",
)
assert_allclose(
kin_dyn_reduced.frame_transform(frame_name=frame_name),
js.frame.transform(
model=model_reduced,
data=data_reduced,
frame_index=js.frame.name_to_idx(
model=model_reduced, frame_name=frame_name
),
),
err_msg=f"Mismatch in frame {frame_name}",
)
def test_model_properties(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Compare whole-model quantities computed by JaxSim with the `kin_dyn` reference.

    A random state is sampled in the requested velocity representation, then the
    total mass, the total momentum (and its jacobian), the locked spatial inertia,
    and the average velocity (and its jacobian) are checked one after the other.
    """

    model = jaxsim_models_types

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=velocity_representation
    )

    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)

    # =====
    # Tests
    # =====

    # The total mass is the only quantity that is queried without the model data.
    assert_allclose(kin_dyn.total_mass(), js.model.total_mass(model=model))

    # The remaining quantities are exposed under the same name by both APIs,
    # so they can be looked up by name and compared in the original order.
    state_dependent_quantities = (
        "total_momentum_jacobian",
        "total_momentum",
        "locked_spatial_inertia",
        "average_velocity_jacobian",
        "average_velocity",
    )

    for quantity in state_dependent_quantities:
        reference_value = getattr(kin_dyn, quantity)()
        jaxsim_value = getattr(js.model, quantity)(model=model, data=data)
        assert_allclose(reference_value, jaxsim_value)
def test_model_rbda(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
    velocity_representation: VelRepr,
):
    """
    Check the rigid-body dynamics quantities of JaxSim against the `kin_dyn` reference.

    The comparison covers the mass matrix, the gravity and bias forces, the link
    transforms, the link bias accelerations, and the inverse of the mass matrix.
    """

    model = jaxsim_models_types

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=velocity_representation
    )

    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)

    # =====
    # Tests
    # =====

    # Floating-base models are compared over every row/column, whereas the first
    # six entries are excluded from the comparison for fixed-base models.
    first_compared_entry = 0 if model.floating_base() else 6
    sl = np.s_[first_compared_entry:]

    link_names = model.link_names()

    # Mass matrix
    M_idt = kin_dyn.mass_matrix()
    M_js = js.model.free_floating_mass_matrix(model=model, data=data)
    assert_allclose(M_idt[sl, sl], M_js[sl, sl])

    # Gravity forces
    gravity_idt = kin_dyn.gravity_forces()
    gravity_js = js.model.free_floating_gravity_forces(model=model, data=data)
    assert_allclose(gravity_idt[sl], gravity_js[sl])

    # Bias forces
    bias_idt = kin_dyn.bias_forces()
    bias_js = js.model.free_floating_bias_forces(model=model, data=data)
    assert_allclose(bias_idt[sl], bias_js[sl])

    # Forward kinematics: the transforms stored in the data are compared with the
    # per-link transforms returned by the reference, stacked in link order.
    W_H_links_js = data._link_transforms
    W_H_links_idt = jnp.stack(
        [kin_dyn.frame_transform(frame_name=link_name) for link_name in link_names]
    )
    assert_allclose(W_H_links_idt, W_H_links_js)

    # Bias accelerations
    bias_acc_js = js.model.link_bias_accelerations(model=model, data=data)
    bias_acc_idt = jnp.stack(
        [kin_dyn.frame_bias_acc(frame_name=link_name) for link_name in link_names]
    )
    assert_allclose(bias_acc_idt, bias_acc_js)

    # Mass matrix inverse via RBDA, checked against the explicit inverse of the
    # reference mass matrix.
    M_inv_js = js.model.free_floating_mass_matrix_inverse(model=model, data=data)
    M_inv_idt = jnp.linalg.inv(M_idt)
    assert_allclose(M_inv_idt[sl, sl], M_inv_js[sl, sl])
def test_model_jacobian(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """
    Check that the J.T @ f product does not depend on the representation of f.

    The product is first computed with everything expressed in the inertial
    representation, and then recomputed with the forces and the jacobian output
    expressed in the body and mixed representations.
    """

    model = jaxsim_models_types

    key, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=VelRepr.Inertial
    )

    # =====
    # Tests
    # =====

    def project_forces(jacobian, forces):
        # Contract the link and 6D axes, leaving one entry per generalized coordinate.
        return jnp.einsum("l6g,l6->g", jacobian, forces)

    # Create random references (joint torques and link forces)
    _, subkey1, subkey2 = jax.random.split(key, num=3)
    random_joint_forces = 10 * jax.random.uniform(subkey1, shape=(model.dofs(),))
    random_link_forces = jax.random.uniform(
        subkey2, shape=(model.number_of_links(), 6)
    )

    references = js.references.JaxSimModelReferences.build(
        model=model,
        joint_force_references=random_joint_forces,
        link_forces=random_link_forces,
        data=data,
        velocity_representation=data.velocity_representation,
    )

    # Remove the force applied to the base link if the model is fixed-base
    if not model.floating_base():
        references = references.apply_link_forces(
            forces=jnp.atleast_2d(jnp.zeros(6)),
            model=model,
            data=data,
            link_names=(model.base_link(),),
            additive=False,
        )

    # Get the J.T @ f product in inertial-fixed input/output representation.
    # We use doubly right-trivialized jacobian with inertial-fixed 6D forces.
    with (
        references.switch_velocity_representation(VelRepr.Inertial),
        data.switch_velocity_representation(VelRepr.Inertial),
    ):
        f = references.link_forces(model=model, data=data)

        # In the inertial representation the forces must match the stored ones.
        assert_allclose(f, references._link_forces)

        J = js.model.generalized_free_floating_jacobian(model=model, data=data)
        JTf_inertial = project_forces(J, f)

    for vel_repr in (VelRepr.Body, VelRepr.Mixed):
        with references.switch_velocity_representation(vel_repr):

            # Get the jacobian having an inertial-fixed input representation (so that
            # it computes the same quantity computed above) and an output representation
            # compatible with the frame in which the external forces are expressed.
            with data.switch_velocity_representation(VelRepr.Inertial):
                J = js.model.generalized_free_floating_jacobian(
                    model=model, data=data, output_vel_repr=vel_repr
                )

            # Get the forces in the tested representation and compute the product
            # O_J_WL_W.T @ O_f, which can be tested against the one computed before.
            with data.switch_velocity_representation(vel_repr):
                f = references.link_forces(model=model, data=data)

            JTf_other = project_forces(J, f)
            assert_allclose(JTf_inertial, JTf_other, err_msg=vel_repr.name)
def test_coriolis_matrix(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Validate the free-floating Coriolis matrix C on a random model state.

    Two properties are checked: the product C @ ν must equal the bias forces minus
    the gravity forces, and (Ṁ - C - C.T) must vanish, where Ṁ is obtained by
    differentiating the mass matrix with automatic differentiation.
    """

    model = jaxsim_models_types

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=velocity_representation
    )

    # =====
    # Tests
    # =====

    ν = data.generalized_velocity
    C = js.model.free_floating_coriolis_matrix(model=model, data=data)
    bias_forces = js.model.free_floating_bias_forces(model=model, data=data)
    gravity_forces = js.model.free_floating_gravity_forces(model=model, data=data)

    assert_allclose(C @ ν, bias_forces - gravity_forces)

    # Free-floating mass matrix as a function of the generalized position only.
    # Differentiating this function yields dM/dq, from which Ṁ is assembled.
    # The velocity representation is handled internally when computing M, so the
    # same function is used for all of them.
    def mass_matrix_of(q) -> jax.Array:
        return js.model.free_floating_mass_matrix(
            model=model,
            data=js.data.JaxSimModelData.build(
                model=model,
                velocity_representation=data.velocity_representation,
                base_position=q[:3],
                base_quaternion=q[3:7],
                joint_positions=q[7:],
            ),
        )

    # Generalized position: base position, base quaternion, joint positions.
    q = jnp.hstack([data.base_position, data.base_orientation, data.joint_positions])

    # The angular part of the base velocity is read in the body representation,
    # since it is passed to the quaternion derivative as a body-fixed quantity.
    with data.switch_velocity_representation(VelRepr.Body):
        B_ω_WB = data.base_velocity[3:6]

    # The linear part of the base velocity is read in the mixed representation.
    with data.switch_velocity_representation(VelRepr.Mixed):
        W_ṗ_B = data.base_velocity[0:3]

    W_Q̇_B = jaxsim.math.Quaternion.derivative(
        quaternion=data.base_orientation,
        omega=B_ω_WB,
        omega_in_body_fixed=True,
        K=0.0,
    ).squeeze()

    # Time derivative of the generalized position.
    q̇ = jnp.hstack([W_ṗ_B, W_Q̇_B, data.joint_velocities])

    # Compute Ṁ with AD.
    dM_dq = jax.jacfwd(mass_matrix_of, argnums=0)(q)
    Ṁ = jnp.einsum("ijq,q->ij", dM_dq, q̇)

    # We need to zero the blocks projecting joint variables to the base configuration
    # for fixed-base models.
    if not model.floating_base():
        Ṁ = Ṁ.at[0:6, 6:].set(0)
        Ṁ = Ṁ.at[6:, 0:6].set(0)

    # Ensure that (Ṁ - 2C) is skew symmetric.
    assert_allclose(Ṁ - C - C.T, 0.0)
def test_model_fd_id_consistency(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Check that forward and inverse dynamics are consistent with each other.

    The accelerations produced by the ABA and CRB forward dynamics must match,
    and feeding them to the inverse dynamics must return the joint forces that
    generated them.
    """

    model = jaxsim_models_types

    key, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=velocity_representation
    )

    # =====
    # Tests
    # =====

    # Create random references (joint torques and link forces).
    _, subkey1, subkey2 = jax.random.split(key, num=3)
    random_joint_forces = 10 * jax.random.uniform(subkey1, shape=(model.dofs(),))
    random_link_forces = jax.random.uniform(
        subkey2, shape=(model.number_of_links(), 6)
    )

    references = js.references.JaxSimModelReferences.build(
        model=model,
        joint_force_references=random_joint_forces,
        link_forces=random_link_forces,
        data=data,
        velocity_representation=data.velocity_representation,
    )

    # Remove the force applied to the base link if the model is fixed-base.
    if not model.floating_base():
        references = references.apply_link_forces(
            forces=jnp.atleast_2d(jnp.zeros(6)),
            model=model,
            data=data,
            link_names=(model.base_link(),),
            additive=False,
        )

    # Inputs shared by both forward dynamics algorithms.
    link_forces = references.link_forces(model=model, data=data)
    forward_dynamics_inputs = dict(
        model=model,
        data=data,
        joint_forces=references.joint_force_references(),
        link_forces=link_forces,
    )

    # Compute forward dynamics with ABA and with CRB.
    v̇_WB_aba, s̈_aba = js.model.forward_dynamics_aba(**forward_dynamics_inputs)
    v̇_WB_crb, s̈_crb = js.model.forward_dynamics_crb(**forward_dynamics_inputs)

    assert_allclose(s̈_aba, s̈_crb)
    assert_allclose(v̇_WB_aba, v̇_WB_crb)

    def inverse_dynamics_with(applied_link_forces):
        # Inverse dynamics fed with the accelerations computed by the ABA.
        return js.model.inverse_dynamics(
            model=model,
            data=data,
            joint_accelerations=s̈_aba,
            base_acceleration=v̇_WB_aba,
            link_forces=applied_link_forces,
        )

    τ_expected = references.joint_force_references(model=model)

    # Check consistency between FD and ID
    fB_id, τ_id = inverse_dynamics_with(link_forces)
    assert_allclose(τ_id, τ_expected)
    assert_allclose(fB_id, 0.0)

    if model.floating_base():
        # If we remove the base 6D force from the inputs, we should find it as output.
        fB_id, τ_id = inverse_dynamics_with(link_forces.at[0].set(jnp.zeros(6)))
        assert_allclose(τ_id, τ_expected)
        assert_allclose(fB_id, link_forces[0])
def test_aba_vs_aba_parallel(
    jaxsim_models_all: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Verify that calling the ABA forward dynamics with `parallel=True` gives the
    same base and joint accelerations as calling it with `parallel=False`,
    for random joint and link forces applied to a random model state.
    """

    model = jaxsim_models_all

    key, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=velocity_representation
    )

    # Create random references.
    _, subkey1, subkey2 = jax.random.split(key, num=3)
    random_joint_forces = 10 * jax.random.uniform(subkey1, shape=(model.dofs(),))
    random_link_forces = jax.random.uniform(
        subkey2, shape=(model.number_of_links(), 6)
    )

    references = js.references.JaxSimModelReferences.build(
        model=model,
        joint_force_references=random_joint_forces,
        link_forces=random_link_forces,
        data=data,
        velocity_representation=data.velocity_representation,
    )

    # Fixed-base models get a zero force on their base link.
    if not model.floating_base():
        references = references.apply_link_forces(
            forces=jnp.atleast_2d(jnp.zeros(6)),
            model=model,
            data=data,
            link_names=(model.base_link(),),
            additive=False,
        )

    joint_forces = references.joint_force_references()
    link_forces = references.link_forces(model=model, data=data)

    def run_aba(parallel: bool):
        # Same inputs for both runs, only the `parallel` switch changes.
        return js.model.forward_dynamics_aba(
            model=model,
            data=data,
            joint_forces=joint_forces,
            link_forces=link_forces,
            parallel=parallel,
        )

    v̇_WB_seq, s̈_seq = run_aba(parallel=False)
    v̇_WB_par, s̈_par = run_aba(parallel=True)

    assert_allclose(v̇_WB_seq, v̇_WB_par, atol=1e-9)
    assert_allclose(s̈_seq, s̈_par, atol=1e-9)
def test_fk_vs_fk_parallel(
    jaxsim_models_all: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Verify that the forward kinematics computed with `parallel=True` matches
    the one computed with `parallel=False` on a random model state.
    """

    model = jaxsim_models_all

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=velocity_representation
    )

    # Evaluate the sequential variant first, then the parallel one.
    W_H_seq, W_H_par = (
        js.model.forward_kinematics(model=model, data=data, parallel=parallel)
        for parallel in (False, True)
    )

    assert_allclose(W_H_seq, W_H_par, atol=1e-9)
================================================
FILE: tests/test_api_model_hw_parametrization.py
================================================
import pathlib
import xml.etree.ElementTree as ET
import jax
import jax.numpy as jnp
import numpy as np
import pytest
import rod
import jaxsim.api as js
from jaxsim.api.kin_dyn_parameters import (
HwLinkMetadata,
LinkParametrizableShape,
ScalingFactors,
)
from jaxsim.rbda.contacts import SoftContactsParams
from .utils import assert_allclose
def test_update_hw_link_parameters(jaxsim_model_garpez: js.model.JaxSimModel):
    """
    Check that updating the hardware parameters scales the geometry of every link.

    For each link, the updated geometry must equal the initial geometry multiplied
    element-wise by the scaling factors assigned to that link.
    """

    model = jaxsim_model_garpez

    # Keep the hardware metadata of the nominal model for the comparison.
    initial_metadata = model.kin_dyn_parameters.hw_link_metadata

    # One row of scaling factors per link.
    per_link_scaling = jnp.array(
        [
            [2.0, 1.5, 1.0],  # Scale x, y, z for link1
            [1.2, 1.0, 1.0],  # Scale r for link2
            [1.5, 0.8, 1.0],  # Scale r, l for link3
            [1.5, 1.0, 0.8],  # Scale x, y, z for link4
        ]
    )
    scaling_parameters = ScalingFactors(dims=per_link_scaling, density=jnp.ones(4))

    # Update the model using the scaling factors
    updated_model = js.model.update_hw_parameters(model, scaling_parameters)

    def select_link(metadata, index):
        # Extract the entry of a single link from every leaf of the metadata.
        return jax.tree.map(lambda leaf: leaf[index], metadata)

    # Compare updated hardware parameters
    for link_idx, link_name in enumerate(model.link_names()):

        updated_link_metadata = select_link(
            updated_model.kin_dyn_parameters.hw_link_metadata, link_idx
        )
        initial_link_metadata = select_link(initial_metadata, link_idx)

        # TODO: Compute the 3D scaling vector
        # scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector(
        #     initial_metadata_link.shape, scaling_parameters.dims[link_idx]
        # )
        expected_geometry = (
            initial_link_metadata.geometry * scaling_parameters.dims[link_idx]
        )

        # Compare shape dimensions
        assert_allclose(
            updated_link_metadata.geometry,
            expected_geometry,
            atol=1e-6,
            err_msg=f"Mismatch in dimensions for link {link_name}",
        )
@pytest.mark.parametrize(
    "jaxsim_model_garpez_scaled",
    [
        {
            "link1_scale": 4.0,
            "link2_scale": 3.0,
            "link3_scale": 2.0,
            "link4_scale": 1.5,
        }
    ],
    indirect=True,
)
def test_model_scaling_against_rod(
    jaxsim_model_garpez: js.model.JaxSimModel,
    jaxsim_model_garpez_scaled: js.model.JaxSimModel,
):
    """
    Check that scaling the hardware parameters of the nominal JaxSim model yields
    the same hardware metadata and collidable points as the pre-scaled model
    provided by the `jaxsim_model_garpez_scaled` fixture.
    """

    # Scaling factors applied to the nominal model, one row per link.
    per_link_scaling = jnp.array(
        [
            [4.0, 1.0, 1.0],  # Scale only x-dimension for link1
            [3.0, 1.0, 1.0],  # Scale only r-dimension for link2
            [1.0, 2.0, 1.0],  # Scale l dimension for link3
            [1.5, 1.0, 1.0],  # Scale only x-dimension for link4
        ]
    )
    scaling_parameters = ScalingFactors(dims=per_link_scaling, density=jnp.ones(4))

    # Apply scaling to the original JaxSim model
    updated_model = js.model.update_hw_parameters(
        jaxsim_model_garpez, scaling_parameters
    )

    # Hardware metadata of the model scaled here and of the pre-scaled model.
    scaled_metadata = updated_model.kin_dyn_parameters.hw_link_metadata
    pre_scaled_metadata = jaxsim_model_garpez_scaled.kin_dyn_parameters.hw_link_metadata

    # Compare shape dimensions
    assert_allclose(scaled_metadata.geometry, pre_scaled_metadata.geometry, atol=1e-6)

    # Mass and inertia are returned together, so each metadata is processed once.
    scaled_mass, scaled_inertia = HwLinkMetadata.compute_mass_and_inertia(
        scaled_metadata
    )
    pre_scaled_mass, pre_scaled_inertia = HwLinkMetadata.compute_mass_and_inertia(
        pre_scaled_metadata
    )

    # Compare mass
    assert_allclose(scaled_mass, pre_scaled_mass, atol=1e-6)

    # Compare inertia tensors
    assert_allclose(scaled_inertia, pre_scaled_inertia, atol=1e-6)

    # Compare transformations
    assert_allclose(scaled_metadata.L_H_G, pre_scaled_metadata.L_H_G, atol=1e-6)
    assert_allclose(scaled_metadata.L_H_vis, pre_scaled_metadata.L_H_vis, atol=1e-6)

    # Compare collidable points positions
    pre_scaled_points = (
        jaxsim_model_garpez_scaled.kin_dyn_parameters.contact_parameters.point
    )
    scaled_points = updated_model.kin_dyn_parameters.contact_parameters.point
    assert_allclose(pre_scaled_points, scaled_points, atol=1e-6)
def test_update_hw_parameters_vmap(
    jaxsim_model_garpez: js.model.JaxSimModel,
):
    """
    Check that `js.model.update_hw_parameters` can be vectorized with `jax.vmap`
    over a batch of scaling factors, and that the resulting batch of models can
    be used to compute a link transform and the mass matrix.
    """

    model_nominal = jaxsim_model_garpez
    dofs = model_nominal.dofs()
    number_of_links = model_nominal.number_of_links()

    # Number of updated models to create
    n = 10

    # One uniform scaling per model, applied to all dimensions and densities.
    scaling_factors = [
        ScalingFactors(
            dims=(scale * jnp.ones((number_of_links, 3))),
            density=(scale * jnp.ones(number_of_links)),
        )
        for scale in jnp.linspace(2.0, 2.0 + n - 1, n)
    ]

    # Stack the leaves of all the ScalingFactors along a new leading batch axis.
    scaling_factors = jax.tree.map(
        lambda *leaves: jnp.stack(leaves), *scaling_factors
    )

    # Generate a batch of updated models: the nominal model is shared, while the
    # scaling factors are mapped over their leading axis.
    batched_update = jax.vmap(js.model.update_hw_parameters, in_axes=(None, 0))
    updated_models = batched_update(model_nominal, scaling_factors)

    def validate_model(updated_model):
        assert updated_model is not None

        # Compute forward kinematics for the "link3" link
        H_link3 = js.link.transform(
            model=updated_model,
            data=js.data.JaxSimModelData.build(model=updated_model),
            link_index=js.link.name_to_idx(model=updated_model, link_name="link3"),
        )

        # Compute the mass matrix
        M = js.model.free_floating_mass_matrix(
            model=updated_model,
            data=js.data.JaxSimModelData.build(model=updated_model),
        )

        # Inside vmap the shapes are those of a single (unbatched) model.
        assert H_link3 is not None
        assert H_link3.shape == (4, 4)
        assert M is not None
        assert isinstance(M, jnp.ndarray)
        assert M.shape == (6 + dofs, 6 + dofs)

    # Use vmap to validate all updated models
    jax.vmap(validate_model)(updated_models)
@pytest.mark.parametrize(
    "jaxsim_model_garpez_scaled",
    [
        {
            "link1_scale": 4.0,
            "link2_scale": 3.0,
            "link3_scale": 2.0,
            "link4_scale": 1.5,
        }
    ],
    indirect=True,
)
def test_export_updated_model(
    jaxsim_model_garpez: js.model.JaxSimModel,
    jaxsim_model_garpez_scaled: js.model.JaxSimModel,
):
    """
    Check the URDF produced by `JaxSimModel.export_updated_model`.

    The model is updated once with non-trivial scaling factors and once with
    identity scaling. Each exported URDF is loaded back with ROD and its links are
    compared (visual geometry, mass, inertia, collision geometry) with a reference
    ROD model: the pre-scaled model in the first case, the nominal one in the second.
    """

    model: js.model.JaxSimModel = jaxsim_model_garpez
    number_of_links = model.number_of_links()

    # Non-trivial scaling, one row per link.
    scaling_parameters = ScalingFactors(
        dims=jnp.array(
            [
                [4.0, 1.0, 1.0],  # Scale x-dimension for link1
                [3.0, 1.0, 1.0],  # Scale r-dimension for link2
                [1.0, 2.0, 1.0],  # Scale l-dimension for link3
                [1.5, 1.0, 1.0],  # Scale x-dimension for link4
            ]
        ),
        density=jnp.ones(4),
    )

    # Scaling that leaves every dimension and density unchanged.
    identity_scaling = ScalingFactors(
        dims=jnp.ones((number_of_links, 3)),
        density=jnp.ones(number_of_links),
    )

    def get_link_by_name(model, name):
        # Return the first link with the given name, or raise a descriptive error.
        matching_links = (link for link in model.links() if link.name == name)
        try:
            return next(matching_links)
        except StopIteration as err:
            raise ValueError(
                f"Link '{name}' not found. Available links: {[l.name for l in model.links()]}"
            ) from err

    def compare_geometries(exported_link, ref_link, label=""):
        # Compare the attributes of the visual geometry shared by both links.
        exported_geom = exported_link.visual.geometry.geometry()
        ref_geom = ref_link.visual.geometry.geometry()

        shared_attrs = [
            attr for attr in vars(exported_geom) if hasattr(ref_geom, attr)
        ]

        assert_allclose(
            jnp.array([getattr(exported_geom, attr) for attr in shared_attrs]),
            jnp.array([getattr(ref_geom, attr) for attr in shared_attrs]),
            err_msg=f"Geometry mismatch in {label} model.",
            atol=1e-6,
        )

    def compare_mass_and_inertia(exported_link, ref_link, label=""):
        exported_inertial = exported_link.inertial
        ref_inertial = ref_link.inertial

        assert_allclose(
            exported_inertial.mass,
            ref_inertial.mass,
            atol=1e-4,
            err_msg=f"Mass mismatch in {label} model.",
        )
        assert_allclose(
            exported_inertial.inertia.matrix(),
            ref_inertial.inertia.matrix(),
            atol=1e-4,
            err_msg=f"Inertia matrix mismatch in {label} model.",
        )

    def compare_collisions(exported_link, ref_link, label=""):
        # The first geometry type defined in the reference link is compared and
        # the function returns; if none is defined the test is skipped.
        for geom_type in ["box", "sphere", "cylinder"]:
            exp_geom = getattr(exported_link.collision.geometry, geom_type)
            ref_geom = getattr(ref_link.collision.geometry, geom_type)

            if ref_geom is None:
                continue

            if geom_type == "box":
                assert_allclose(
                    jnp.array(exp_geom.size), jnp.array(ref_geom.size), atol=1e-6
                )
            elif geom_type == "sphere":
                assert_allclose(exp_geom.radius, ref_geom.radius, atol=1e-6)
            elif geom_type == "cylinder":
                assert_allclose(exp_geom.radius, ref_geom.radius, atol=1e-6)
                assert_allclose(exp_geom.length, ref_geom.length, atol=1e-6)

            return

        pytest.skip(
            f"Collision geometry type for link {exported_link.name} not supported."
        )

    def validate_model(updated_model, ref_model, label):
        # Export the model and load the result back as a ROD model.
        urdf = updated_model.export_updated_model()
        assert isinstance(urdf, str), f"{label}: Exported URDF is not a string."

        exported_sdf = rod.Sdf.load(urdf, is_urdf=True)
        assert (
            len(exported_sdf.models()) == 1
        ), f"{label}: Exported model does not contain exactly one ROD model."

        exported_model = exported_sdf.models()[0]

        # Every link of the nominal model must be found in both ROD models.
        for link_name in model.link_names():
            exported_link = get_link_by_name(exported_model, link_name)
            ref_link = get_link_by_name(ref_model, link_name)

            compare_geometries(exported_link, ref_link, label=label)
            compare_mass_and_inertia(exported_link, ref_link, label=label)
            compare_collisions(exported_link, ref_link, label=label)

    # Test both scaled and identity-scaled updates. Each case carries the JaxSim
    # model whose source description is used as reference.
    test_cases = (
        (scaling_parameters, "SCALED", jaxsim_model_garpez_scaled),
        (identity_scaling, "IDENTITY SCALED", jaxsim_model_garpez),
    )

    for scaling, label, reference_source in test_cases:
        # Load reference ROD model
        ref_model = rod.Sdf.load(reference_source.built_from).models()[0]

        updated_model = js.model.update_hw_parameters(model, scaling)
        validate_model(updated_model, ref_model, label)
def test_hw_parameters_optimization(jaxsim_model_garpez: js.model.JaxSimModel):
    """
    Test that updating hardware parameters allows optimizing the position of a link
    to match a target value along a specific world axis.

    The scaling factors of the model are optimized with plain gradient descent so
    that the z-coordinate of the "link4" transform reaches the target height.
    """

    model = jaxsim_model_garpez
    data = js.data.JaxSimModelData.build(model=model)

    # Define the target height for the link.
    target_height = 3.0

    # Get the index of the link to optimize.
    link_idx = js.link.name_to_idx(model, link_name="link4")

    # Define the initial hardware parameters (scaling factors): 1.0 for all the
    # dimensions and all the densities.
    initial_dims = jnp.ones((model.number_of_links(), 3))
    initial_density = jnp.ones((model.number_of_links(),))

    scaling_factors = js.kin_dyn_parameters.ScalingFactors(
        dims=initial_dims, density=initial_density
    )

    # Define the loss function.
    def loss(scaling_factors):
        # Update the model with the new hardware parameters.
        updated_model = js.model.update_hw_parameters(
            model=model, scaling_factors=scaling_factors
        )

        # Sync the data's cached kinematics (joint transforms, link transforms, …)
        # with the updated model geometry before running any dynamics.
        updated_data = data.replace(model=updated_model)

        # Compute forward kinematics for the link.
        W_H_L = js.model.forward_kinematics(model=updated_model, data=updated_data)[
            link_idx
        ]

        # Extract the height of the link: third row of the translation column.
        link4_height = W_H_L[2, 3]

        # Compute the loss as the squared difference from the target height.
        return (link4_height - target_height) ** 2

    # Compute the gradient of the loss function with respect to the scaling factors.
    loss_grad = jax.grad(loss)

    # Perform gradient descent.
    alpha = 0.01  # Learning rate.
    num_iterations = 1000  # Number of gradient descent steps.

    # The loop index is used for progress reporting, so it gets a real name
    # instead of the throwaway `_`.
    for iteration in range(num_iterations):
        # Compute the gradient.
        grad_scaling_factors = loss_grad(scaling_factors)

        # Update the scaling factors.
        scaling_factors = js.kin_dyn_parameters.ScalingFactors(
            dims=scaling_factors.dims - alpha * grad_scaling_factors.dims,
            density=scaling_factors.density - alpha * grad_scaling_factors.density,
        )

        # Compute the current loss value.
        current_loss = loss(scaling_factors)

        # Print the progress every 100 iterations.
        if iteration % 100 == 0:
            print(f"Iteration {iteration}: Loss = {current_loss}")

    # Assert that the final loss is close to zero.
    assert current_loss < 1e-3, "Optimization did not converge to the target height."
def test_hw_parameters_collision_scaling(
    jaxsim_model_box: js.model.JaxSimModel, prng_key: jax.Array
):
    """
    Check that the collision elements follow the scaling of the hardware parameters.

    A box is scaled uniformly, dropped from slightly above the ground, and
    simulated for a fixed number of steps. Its base must come to rest at half of
    the scaled box height.
    """

    _, subkey = jax.random.split(prng_key, num=2)

    # TODO: the jaxsim_model_box has an additional frame, which is handled wrongly
    # during the export of the updated model. For this reason, we recreate the model
    # from scratch here.
    del jaxsim_model_box

    import rod.builder.primitives

    # Create on-the-fly a ROD model of a box.
    box_builder = rod.builder.primitives.BoxBuilder(
        x=0.3, y=0.2, z=0.1, mass=1.0, name="box"
    )
    rod_model = (
        box_builder.build_model()
        .add_link(name="box_link")
        .add_inertial()
        .add_visual()
        .add_collision()
        .build()
    )

    model = js.model.JaxSimModel.build_from_model_description(
        model_description=rod_model
    )

    # Uniform scaling factor applied to all the dimensions of the model.
    scaling_factor = 5.0

    # Recompute K and D, since the mass is scaled by scaling_factor^3
    # and the expected static compression of the terrain is approximately
    # proportional to mass/K and divided by the 4 contact points.
    K = model.contact_params.K * (scaling_factor**2)

    # Strongly overdamped, to avoid oscillations due to the high mass
    # and the low penetration allowed.
    D = 8 * jnp.sqrt(K)

    with model.editable(validate=False) as model:
        model.contact_params = SoftContactsParams(K=K, D=D)

    # Nominal size of the box along the third geometry axis (its height).
    nominal_height = model.kin_dyn_parameters.hw_link_metadata.geometry[0, 2]

    # Scale all the dimensions and keep the density unchanged.
    scaling_parameters = ScalingFactors(
        dims=jnp.ones((model.number_of_links(), 3)) * scaling_factor,
        density=jnp.array([1.0]),
    )

    # Update the model with the scaling parameters
    updated_model = js.model.update_hw_parameters(model, scaling_parameters)

    # Expected resting height of the base: half of the scaled box height.
    expected_height = nominal_height * scaling_factor / 2

    # Start from a random horizontal position, slightly above the expected resting
    # height, so that the box has room to fall and settle on the ground.
    horizontal_position = jax.random.uniform(subkey, shape=(2,))
    data = js.data.JaxSimModelData.build(
        model=updated_model,
        base_position=jnp.array(
            [
                *horizontal_position,
                expected_height + 0.05,
            ]
        ),
    )

    # Simulate the box falling under gravity.
    num_steps = 1000
    for _ in range(num_steps):
        data = js.model.step(
            model=updated_model,
            data=data,
        )

    # Get the final height of the box's base
    updated_base_height = data.base_position[2]

    # Assert that the box settles at the expected height
    assert jnp.isclose(
        updated_base_height, expected_height, atol=1e-3
    ), f"model base height mismatch: expected {expected_height}, got {updated_base_height}"
def test_unsupported_link_cases():
    """
    Test that unsupported link cases are handled correctly.

    Three scenarios are covered: a model whose only link has no visual, a
    multi-link URDF mixing supported and unsupported links, and the same URDF
    built with an explicit selection of the links to parametrize.
    """

    import rod.builder.primitives

    from jaxsim.api.kin_dyn_parameters import LinkParametrizableShape

    # Test unsupported (no visual)
    no_visual_model = js.model.JaxSimModel.build_from_model_description(
        rod.builder.primitives.BoxBuilder(x=1, y=1, z=1, mass=1, name="no_vis_box")
        .build_model()
        .add_link(name="no_visual_link")
        .add_inertial()
        .build()  # No .add_visual()
    )

    no_visual_metadata = no_visual_model.kin_dyn_parameters.hw_link_metadata
    empty_metadata = HwLinkMetadata.empty()

    # Compare the two metadata leaf by leaf and require all leaves to match.
    comparison = jax.tree.map(
        jnp.allclose,
        no_visual_metadata,
        empty_metadata,
    )

    assert jax.tree.reduce(
        lambda acc, value: acc and bool(value), comparison, True
    ), "No links should be supported."

    # Create a simple multi-link URDF and add collision to ensure links are kept
    # NOTE(review): this literal is empty here, while the assertions below expect
    # three links named "supported_link", "unsupported_link" and
    # "double_visual_link" — confirm the URDF content.
    multi_link_urdf = """
"""

    # Build JaxSim model from the URDF
    multi_link_model = js.model.JaxSimModel.build_from_model_description(
        multi_link_urdf, is_urdf=True
    )
    multi_link_metadata = multi_link_model.kin_dyn_parameters.hw_link_metadata

    # Verify array consistency for the model
    num_links = multi_link_model.number_of_links()
    assert num_links == 3, f"Expected 3 links in the URDF model, got {num_links}"
    assert (
        len(multi_link_metadata.link_shape)
        == len(multi_link_metadata.geometry)
        == len(multi_link_metadata.density)
        == num_links
    )

    # Count verification in single model
    supported_count = sum(
        1
        for s in multi_link_metadata.link_shape
        if s != LinkParametrizableShape.Unsupported
    )
    unsupported_count = sum(
        1
        for s in multi_link_metadata.link_shape
        if s == LinkParametrizableShape.Unsupported
    )

    assert (
        supported_count == 2
    ), f"Expected 2 supported links in single model, got {supported_count}"
    assert (
        unsupported_count == 1
    ), f"Expected 1 unsupported link in single model, got {unsupported_count}"

    # Ensure shapes match expectations by name
    link_indices = {name: idx for idx, name in enumerate(multi_link_model.link_names())}

    assert (
        multi_link_metadata.link_shape[link_indices["supported_link"]]
        == LinkParametrizableShape.Box
    ), "Supported link should remain a box"
    assert (
        multi_link_metadata.link_shape[link_indices["unsupported_link"]]
        == LinkParametrizableShape.Unsupported
    ), "Unsupported link should remain unsupported"

    double_visual_idx = link_indices["double_visual_link"]
    assert (
        multi_link_metadata.link_shape[double_visual_idx]
        == LinkParametrizableShape.Sphere
    ), "Double visual link should pick the first (sphere) visual"
    assert_allclose(
        multi_link_metadata.geometry[double_visual_idx, 0],
        0.4,
        err_msg="Sphere radius must match the first visual",
    )

    # Test selective parametrization: only 'double_visual_link' is selected, so
    # every other link must be reported as Unsupported.
    # The trailing comma matters: without it the argument is a plain string
    # instead of a one-element tuple of link names.
    selective_model = js.model.JaxSimModel.build_from_model_description(
        multi_link_urdf, is_urdf=True, parametrized_links=("double_visual_link",)
    )
    selective_metadata = selective_model.kin_dyn_parameters.hw_link_metadata

    # Check that only the selected links are parametrized
    link_indices = {name: idx for idx, name in enumerate(selective_model.link_names())}

    assert (
        selective_metadata.link_shape[link_indices["supported_link"]]
        == LinkParametrizableShape.Unsupported
    ), "Non-selected supported_link should be marked as Unsupported"
    assert (
        selective_metadata.link_shape[link_indices["double_visual_link"]]
        == LinkParametrizableShape.Sphere
    ), "Selected double_visual_link should be parametrized as Sphere"
    assert (
        selective_metadata.link_shape[link_indices["unsupported_link"]]
        == LinkParametrizableShape.Unsupported
    ), "Non-selected unsupported_link should be marked as Unsupported"
def test_export_continuous_joint_handling():
    """
    Check how joints are written by `export_updated_model` for the cartpole model.

    The "pivot" joint must be exported with type="continuous", without position
    limits but with effort and velocity limits. The "linear" joint must stay
    prismatic and keep its position limits.
    """

    # Load cartpole model which has a continuous joint (pivot)
    repository_root = pathlib.Path(__file__).parent.parent
    cartpole_path = repository_root / "examples" / "assets" / "cartpole.urdf"
    model = js.model.JaxSimModel.build_from_model_description(cartpole_path)

    # Identity scaling: the export is exercised without changing the model.
    number_of_links = model.number_of_links()
    scaling_parameters = ScalingFactors(
        dims=jnp.ones((number_of_links, 3)),
        density=jnp.ones(number_of_links),
    )

    # Update the model with scaling parameters and export it
    updated_model = js.model.update_hw_parameters(model, scaling_parameters)
    exported_urdf = updated_model.export_updated_model()

    # Parse the URDF XML directly (not through rod, which would convert continuous back to revolute)
    root = ET.fromstring(exported_urdf)

    def find_joint(joint_name):
        # Return the first <joint> element with the given name, or None.
        return next(
            (
                joint_elem
                for joint_elem in root.findall(".//joint")
                if joint_elem.get("name") == joint_name
            ),
            None,
        )

    # Find the pivot joint (continuous joint)
    pivot_joint = find_joint("pivot")
    assert pivot_joint is not None, "pivot joint should exist in exported model"

    # Verify that the joint type is "continuous"
    assert (
        pivot_joint.get("type") == "continuous"
    ), f"pivot joint should have type='continuous', got '{pivot_joint.get('type')}'"

    # Verify that position limits are not present for continuous joints
    limit_elem = pivot_joint.find("limit")
    assert limit_elem is not None, "pivot joint should have limits element"
    assert (
        limit_elem.get("lower") is None
    ), f"continuous joint should not have lower position limit, got {limit_elem.get('lower')}"
    assert (
        limit_elem.get("upper") is None
    ), f"continuous joint should not have upper position limit, got {limit_elem.get('upper')}"

    # Verify that effort and velocity limits are preserved
    assert (
        limit_elem.get("effort") is not None
    ), "continuous joint should preserve effort limit"
    assert (
        limit_elem.get("velocity") is not None
    ), "continuous joint should preserve velocity limit"

    # Verify that the linear joint (prismatic) is NOT changed to continuous
    linear_joint = find_joint("linear")
    assert linear_joint is not None, "linear joint should exist in exported model"
    assert (
        linear_joint.get("type") == "prismatic"
    ), f"linear joint should remain prismatic, got '{linear_joint.get('type')}'"

    # Prismatic joint should keep its limits
    linear_limit = linear_joint.find("limit")
    assert linear_limit is not None, "prismatic joint should have limits"
    assert (
        linear_limit.get("lower") is not None
    ), "prismatic joint should have lower limit"
    assert (
        linear_limit.get("upper") is not None
    ), "prismatic joint should have upper limit"
def test_export_model_with_missing_collision(
    jaxsim_model_missing_collision: js.model.JaxSimModel,
):
    """
    Check that export_updated_model() copes with a link that defines a visual
    element but no collision element.

    The exported URDF is inspected as raw XML, then loaded again through rod
    and through JaxSim to make sure it is still a usable description.
    """

    source_model = jaxsim_model_missing_collision

    # One row of dimension scaling factors per link; densities stay at 1.
    dims_scaling = jnp.array(
        [
            [1.5, 1.0, 1.0],  # link1 (has collision): scale the x-dimension
            [2.0, 1.0, 1.0],  # link2 (missing collision): scale the radius
        ]
    )
    factors = ScalingFactors(dims=dims_scaling, density=jnp.ones(2))

    scaled_model = js.model.update_hw_parameters(source_model, factors)

    # Exporting must not fail even though link2 has no collision element.
    urdf_string = scaled_model.export_updated_model()

    assert isinstance(urdf_string, str), "Exported URDF should be a string"
    assert len(urdf_string) > 0, "Exported URDF should not be empty"

    # The exported string has to be well-formed XML rooted at <robot>.
    robot = ET.fromstring(urdf_string)
    assert robot.tag == "robot", "Root element should be 'robot'"

    # Index the exported links by their name attribute.
    links_by_name = {}
    for link_elem in robot.findall(".//link"):
        links_by_name[link_elem.get("name")] = link_elem

    assert "link1" in links_by_name, "link1 should exist in exported model"
    assert "link2" in links_by_name, "link2 should exist in exported model"

    # link1: both visual and collision are expected.
    first_link = links_by_name["link1"]
    first_visual = first_link.find("visual")
    first_collision = first_link.find("collision")
    assert first_visual is not None, "link1 should have a visual element"
    assert first_collision is not None, "link1 should have a collision element"

    # link1: the first box dimension must reflect the 1.5 scaling factor.
    first_box = first_visual.find(".//box")
    assert first_box is not None, "link1 visual should have box geometry"
    box_size = [float(token) for token in first_box.get("size").split()]
    assert_allclose(
        box_size[0],
        0.3 * 1.5,
        atol=1e-6,
        err_msg="link1 x-dimension should be scaled",
    )

    # link2: a visual is expected, a collision is not.
    second_link = links_by_name["link2"]
    second_visual = second_link.find("visual")
    second_collision = second_link.find("collision")
    assert second_visual is not None, "link2 should have a visual element"
    assert second_collision is None, "link2 should NOT have a collision element"

    # link2: the sphere radius must reflect the 2.0 scaling factor.
    second_sphere = second_visual.find(".//sphere")
    assert second_sphere is not None, "link2 visual should have sphere geometry"
    sphere_radius = float(second_sphere.get("radius"))
    assert_allclose(
        sphere_radius,
        0.1 * 2.0,
        atol=1e-6,
        err_msg="link2 radius should be scaled despite missing collision",
    )

    # The exported description must be loadable by rod as a single model.
    reloaded_sdf = rod.Sdf.load(urdf_string, is_urdf=True)
    assert (
        len(reloaded_sdf.models()) == 1
    ), "Exported model should contain exactly one ROD model"

    reloaded_model = reloaded_sdf.models()[0]
    assert (
        reloaded_model.name == source_model.name()
    ), "Exported model name should match"

    # Finally, JaxSim itself must be able to build a model from the export.
    _ = js.model.JaxSimModel.build_from_model_description(
        model_description=urdf_string, is_urdf=True
    )
def test_export_mesh_scaling_preserves_nonzero_visual_and_joint_origins(
tmp_path: pathlib.Path,
):
"""
Regression test for mesh export:
non-identity scaling must preserve non-zero visual/joint origins in the URDF.
"""
# The test depends on the cube.stl asset stored next to this file and is
# skipped when that asset is not available.
mesh_file = pathlib.Path(__file__).parent / "assets" / "cube.stl"
if not mesh_file.exists():
pytest.skip(f"Test mesh file not found: {mesh_file}")
# Write the URDF used by the test into the pytest temporary directory.
# NOTE(review): the f-string body below is empty in this view; presumably the
# URDF markup (with a 'mesh_link' link and a 'base_to_mesh' joint, which the
# rest of the test looks up) was lost — verify against the repository.
urdf_path = tmp_path / "mesh_origin_regression.urdf"
urdf_path.write_text(
f"""
""",
encoding="utf-8",
)
# Build the model with 'mesh_link' as the only parametrized link.
model = js.model.JaxSimModel.build_from_model_description(
model_description=urdf_path,
is_urdf=True,
parametrized_links=("mesh_link",),
)
# Apply a non-uniform scaling to 'mesh_link' only; every other link keeps
# unit scaling factors and all densities are left at 1.
mesh_link_idx = js.link.name_to_idx(model=model, link_name="mesh_link")
dims = jnp.ones((model.number_of_links(), 3))
dims = dims.at[mesh_link_idx].set(jnp.array([1.7, 0.8, 1.3]))
scaling = ScalingFactors(dims=dims, density=jnp.ones(model.number_of_links()))
updated_model = js.model.update_hw_parameters(model=model, scaling_factors=scaling)
# Export the scaled model and parse the result as XML.
exported_urdf = updated_model.export_updated_model()
root = ET.fromstring(exported_urdf)
# The exported visual origin must equal the translation stored in the updated
# hardware metadata (L_H_vis) and must not have collapsed to zero.
visual_origin = root.find(".//link[@name='mesh_link']/visual/origin")
assert visual_origin is not None, "Mesh visual origin must exist in exported URDF"
visual_xyz = np.array([float(v) for v in visual_origin.get("xyz").split()])
expected_visual_xyz = np.array(
updated_model.kin_dyn_parameters.hw_link_metadata.L_H_vis[mesh_link_idx][:3, 3]
)
assert_allclose(
visual_xyz,
expected_visual_xyz,
atol=1e-8,
err_msg="Exported mesh visual origin does not match updated metadata",
)
assert not np.allclose(visual_xyz, np.zeros(3), atol=1e-12)
# The exported joint origin must equal the translation of the updated joint
# transform (λ_H_pre) and must not have collapsed to zero.
joint_origin = root.find(".//joint[@name='base_to_mesh']/origin")
assert joint_origin is not None, "Joint origin must exist in exported URDF"
joint_xyz = np.array([float(v) for v in joint_origin.get("xyz").split()])
joint_idx = js.joint.name_to_idx(model=updated_model, joint_name="base_to_mesh")
# NOTE(review): the '+ 1' offset presumably skips a leading entry of λ_H_pre
# that does not correspond to a joint — confirm against the joint model.
expected_joint_xyz = np.array(
updated_model.kin_dyn_parameters.joint_model.λ_H_pre[joint_idx + 1][:3, 3]
)
assert_allclose(
joint_xyz,
expected_joint_xyz,
atol=1e-8,
err_msg="Exported joint origin does not match updated joint transform",
)
assert not np.allclose(joint_xyz, np.zeros(3), atol=1e-12)
# The exported URDF must be usable to build a JaxSim model with the same
# number of links as the original one.
reimported_jaxsim_model = js.model.JaxSimModel.build_from_model_description(
model_description=exported_urdf, is_urdf=True
)
assert (
reimported_jaxsim_model is not None
), "Should be able to build model from exported URDF"
assert (
reimported_jaxsim_model.number_of_links() == model.number_of_links()
), "Reimported model should have same number of links"
# =============================================================================
# Mesh Scaling Tests
# =============================================================================
def test_mesh_shape_enum():
    """Check that LinkParametrizableShape exposes a Mesh member equal to 3."""

    # The member has to exist before its value can be compared.
    assert hasattr(LinkParametrizableShape, "Mesh")

    mesh_member = LinkParametrizableShape.Mesh
    assert mesh_member == 3
def test_mixed_shapes_metadata():
    """
    Load the mixed-shapes robot with every link parametrized and verify the
    hardware metadata: shape type per link and mesh data only on the mesh link.
    """

    assets_dir = pathlib.Path(__file__).parent / "assets"

    # Both the URDF and the mesh it refers to are required; skip otherwise.
    test_urdf = assets_dir / "mixed_shapes_robot.urdf"
    if not test_urdf.exists():
        pytest.skip(f"Test URDF not found: {test_urdf}")

    mesh_file = assets_dir / "cube.stl"
    if not mesh_file.exists():
        pytest.skip(f"Test mesh not found: {mesh_file}")

    link_names = ("box_link", "cylinder_link", "mesh_link", "sphere_link")
    model = js.model.JaxSimModel.build_from_model_description(
        model_description=test_urdf,
        is_urdf=True,
        parametrized_links=link_names,
    )

    assert model.name() == "mixed_shapes_robot"
    assert model.number_of_links() == 4

    hw_meta = model.kin_dyn_parameters.hw_link_metadata

    # Each of the four links must be tagged with the matching shape type.
    assert len(hw_meta.link_shape) == 4
    expected_shapes = (
        LinkParametrizableShape.Box,
        LinkParametrizableShape.Cylinder,
        LinkParametrizableShape.Mesh,
        LinkParametrizableShape.Sphere,
    )
    for link_idx, shape in enumerate(expected_shapes):
        assert hw_meta.link_shape[link_idx] == shape

    # Vertices are stored only for the mesh link (index 2); the primitive
    # shapes (box, cylinder, sphere) carry None.
    mesh_link_idx = 2
    assert hw_meta.mesh_vertices is not None
    for link_idx in range(4):
        if link_idx == mesh_link_idx:
            assert hw_meta.mesh_vertices[link_idx] is not None
        else:
            assert hw_meta.mesh_vertices[link_idx] is None

    # The mesh link must also carry its faces.
    assert hw_meta.mesh_faces is not None
    assert hw_meta.mesh_faces[mesh_link_idx] is not None
def test_mixed_shapes_scaling():
    """
    Scale a robot mixing primitive and mesh shapes, first with the same factor
    on every link and then with a different factor per link, and check that
    each link mass changes by the cube of its scaling factor.
    """

    assets_dir = pathlib.Path(__file__).parent / "assets"

    test_urdf = assets_dir / "mixed_shapes_robot.urdf"
    if not test_urdf.exists():
        pytest.skip(f"Test URDF not found: {test_urdf}")

    mesh_file = assets_dir / "cube.stl"
    if not mesh_file.exists():
        pytest.skip(f"Test mesh not found: {mesh_file}")

    model = js.model.JaxSimModel.build_from_model_description(
        model_description=test_urdf,
        is_urdf=True,
        parametrized_links=("box_link", "cylinder_link", "mesh_link", "sphere_link"),
    )

    hw_meta = model.kin_dyn_parameters.hw_link_metadata
    if len(hw_meta.link_shape) == 0:
        pytest.skip("Hardware parametrization not supported")

    # Reference masses of the unscaled model, keyed by link name.
    masses_orig = {
        js.link.idx_to_name(model=model, link_index=idx): float(
            model.kin_dyn_parameters.link_parameters.mass[idx]
        )
        for idx in range(model.number_of_links())
    }

    # ---------------------------------------------------------------
    # Uniform scaling: 2x on every dimension, so every mass becomes 8x
    # ---------------------------------------------------------------

    uniform_scaling = ScalingFactors(
        dims=jnp.ones((4, 3)) * 2.0,
        density=jnp.ones(4),
    )
    scaled_uniform = js.model.update_hw_parameters(model, uniform_scaling)

    for idx in range(scaled_uniform.number_of_links()):
        link_name = js.link.idx_to_name(model=scaled_uniform, link_index=idx)
        mass_scaled = float(scaled_uniform.kin_dyn_parameters.link_parameters.mass[idx])
        ratio = mass_scaled / masses_orig[link_name]

        assert jnp.allclose(
            ratio, 8.0, rtol=0.1
        ), f"Uniform scaling: {link_name} expected 8x, got {ratio:.2f}x"

    # ---------------------------------------------------------------
    # Per-link scaling: each link gets its own isotropic factor
    # ---------------------------------------------------------------

    per_link_dims = jnp.array(
        [
            [2.0, 2.0, 2.0],  # box: 8x
            [3.0, 3.0, 3.0],  # cylinder: 27x
            [1.5, 1.5, 1.5],  # mesh: 3.375x
            [2.5, 2.5, 2.5],  # sphere: 15.625x
        ]
    )
    different_scaling = ScalingFactors(dims=per_link_dims, density=jnp.ones(4))
    scaled_different = js.model.update_hw_parameters(model, different_scaling)

    expected_ratios = {
        "box_link": 8.0,
        "cylinder_link": 27.0,
        "mesh_link": 3.375,
        "sphere_link": 15.625,
    }

    for idx in range(scaled_different.number_of_links()):
        link_name = js.link.idx_to_name(model=scaled_different, link_index=idx)
        mass_scaled = float(
            scaled_different.kin_dyn_parameters.link_parameters.mass[idx]
        )
        ratio = mass_scaled / masses_orig[link_name]
        expected = expected_ratios[link_name]

        assert jnp.allclose(
            ratio, expected, rtol=0.1
        ), f"Different scaling: {link_name} expected {expected}x, got {ratio:.2f}x"
================================================
FILE: tests/test_automatic_differentiation.py
================================================
import os
import jax
import jax.numpy as jnp
import numpy as np
from jax.test_util import check_grads
import jaxsim.api as js
import jaxsim.math
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim import VelRepr
from jaxsim.rbda.contacts import SoftContacts, SoftContactsParams
from .utils import assert_allclose
# All JaxSim algorithms, excluding the variable-step integrators, should support
# being automatically differentiated until second order, both in FWD and REV modes.
# However, checking the second-order derivatives is particularly slow and makes
# CI tests take too long. Therefore, we only check first-order derivatives.
#
# Environment variables are always strings, so the value is converted explicitly:
# the tests pass it as the `order` argument of `check_grads`.
AD_ORDER = int(os.environ.get("JAXSIM_TEST_AD_ORDER", 1))

# Define the step size used to compute finite differences depending on the
# floating point resolution.
#
# As above, the value is converted explicitly so that an override coming from the
# environment reaches the `eps` argument of `check_grads` as a float, not a string.
ε = float(
    os.environ.get(
        "JAXSIM_TEST_FD_STEP_SIZE",
        jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3),
    )
)
def get_random_data_and_references(
    model: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    key: jax.Array,
) -> tuple[js.data.JaxSimModelData, js.references.JaxSimModelReferences]:
    """
    Build random model data together with random references for the given model.

    Args:
        model: The model for which data and references are generated.
        velocity_representation: The velocity representation used by both objects.
        key: The PRNG key from which all the random quantities are drawn.

    Returns:
        A tuple with the random data and the random references. For fixed-base
        models, the link force applied to the base link is reset to zero.
    """

    # Draw the random model data from a dedicated subkey.
    key, data_key = jax.random.split(key, num=2)
    data = js.data.random_model_data(
        model=model, key=data_key, velocity_representation=velocity_representation
    )

    # Draw the random joint and link forces from two further subkeys.
    _, joint_key, link_key = jax.random.split(key, num=3)
    random_joint_forces = 10 * jax.random.uniform(joint_key, shape=(model.dofs(),))
    random_link_forces = jax.random.uniform(
        link_key, shape=(model.number_of_links(), 6)
    )

    references = js.references.JaxSimModelReferences.build(
        model=model,
        joint_force_references=random_joint_forces,
        link_forces=random_link_forces,
        data=data,
        velocity_representation=velocity_representation,
    )

    # Floating-base models keep the force applied to the base link.
    if model.floating_base():
        return data, references

    # Fixed-base model: overwrite the force applied to the base link with zeros.
    references = references.apply_link_forces(
        forces=jnp.atleast_2d(jnp.zeros(6)),
        model=model,
        data=data,
        link_names=(model.base_link(),),
        additive=False,
    )

    return data, references
def test_ad_aba(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """Check the derivatives of jaxsim.rbda.aba against finite differences."""

    model = jaxsim_models_types

    subkey = jax.random.split(prng_key, num=2)[1]
    data, references = get_random_data_and_references(
        model=model, velocity_representation=VelRepr.Inertial, key=subkey
    )

    # Closure exposing only the quantities to be differentiated.
    def aba_closure(
        base_pos, base_quat, joint_pos, base_vel, joint_vel, joint_torques, f_links, gravity
    ):
        import jaxlie

        # Finite differences perturb the quaternion, hence the normalization.
        unit_quat = base_quat / jnp.linalg.norm(base_quat)

        base_transform = jaxlie.SE3.from_rotation_and_translation(
            rotation=jaxlie.SO3(wxyz=unit_quat),
            translation=base_pos,
        ).as_matrix()

        transforms = model.kin_dyn_parameters.joint_transforms(
            joint_positions=joint_pos, base_transform=base_transform
        )

        return jaxsim.rbda.aba(
            model=model,
            base_position=base_pos,
            base_quaternion=unit_quat,
            joint_positions=joint_pos,
            base_linear_velocity=base_vel[0:3],
            base_angular_velocity=base_vel[3:6],
            joint_velocities=joint_vel,
            joint_transforms=transforms,
            joint_forces=joint_torques,
            link_forces=f_links,
            standard_gravity=gravity,
        )

    # State (in VelRepr.Inertial representation), inputs, and gravity constant,
    # listed in the order expected by the closure.
    differentiated_args = (
        data.base_position,
        data.base_orientation,
        data.joint_positions,
        data.base_velocity,
        data.joint_velocities,
        references.joint_force_references(model=model),
        references.link_forces(model=model),
        jaxsim.math.STANDARD_GRAVITY,
    )

    # Check derivatives against finite differences.
    check_grads(
        f=aba_closure,
        args=differentiated_args,
        order=AD_ORDER,
        modes=["rev", "fwd"],
        eps=ε,
    )
def test_ad_aba_parallel(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """
    Check that jaxsim.rbda.aba_parallel returns the same result as
    jaxsim.rbda.aba, and check its derivatives against finite differences.
    """

    model = jaxsim_models_types

    subkey = jax.random.split(prng_key, num=2)[1]
    data, references = get_random_data_and_references(
        model=model, velocity_representation=VelRepr.Inertial, key=subkey
    )

    # Closure around the parallel algorithm exposing only the quantities to be
    # differentiated.
    def aba_par(
        base_pos, base_quat, joint_pos, base_vel, joint_vel, transforms, torques, f_links, gravity
    ):
        return jaxsim.rbda.aba_parallel(
            model=model,
            base_position=base_pos,
            base_quaternion=base_quat / jnp.linalg.norm(base_quat),
            joint_positions=joint_pos,
            base_linear_velocity=base_vel[0:3],
            base_angular_velocity=base_vel[3:6],
            joint_velocities=joint_vel,
            joint_forces=torques,
            joint_transforms=transforms,
            link_forces=f_links,
            standard_gravity=gravity,
        )

    # State, joint transforms, inputs, and gravity constant, listed in the order
    # expected by the closure.
    differentiated_args = (
        data.base_position,
        data.base_orientation,
        data.joint_positions,
        data.base_velocity,
        data.joint_velocities,
        data._joint_transforms,
        references.joint_force_references(model=model),
        references.link_forces(model=model),
        jaxsim.math.STANDARD_GRAVITY,
    )
    (
        base_pos,
        base_quat,
        joint_pos,
        base_vel,
        joint_vel,
        transforms,
        torques,
        f_links,
        gravity,
    ) = differentiated_args

    # Verify parallel ABA matches sequential ABA.
    result_seq = jaxsim.rbda.aba(
        model=model,
        base_position=base_pos,
        base_quaternion=base_quat / jnp.linalg.norm(base_quat),
        joint_positions=joint_pos,
        base_linear_velocity=base_vel[0:3],
        base_angular_velocity=base_vel[3:6],
        joint_velocities=joint_vel,
        joint_transforms=transforms,
        joint_forces=torques,
        link_forces=f_links,
        standard_gravity=gravity,
    )
    result_par = aba_par(*differentiated_args)

    for output_idx in (0, 1):
        assert_allclose(result_seq[output_idx], result_par[output_idx], atol=1e-10)

    # Check derivatives against finite differences.
    check_grads(
        f=aba_par,
        args=differentiated_args,
        order=AD_ORDER,
        modes=["rev", "fwd"],
        eps=ε,
    )
def test_ad_rnea(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """Check the derivatives of jaxsim.rbda.rnea against finite differences."""

    model = jaxsim_models_types

    key, subkey = jax.random.split(prng_key, num=2)
    data, references = get_random_data_and_references(
        model=model, velocity_representation=VelRepr.Inertial, key=subkey
    )

    # Random base and joint accelerations.
    _, base_acc_key, joint_acc_key = jax.random.split(key, num=3)
    base_acc = jax.random.uniform(base_acc_key, shape=(6,), minval=-1)
    joint_acc = jax.random.uniform(joint_acc_key, shape=(model.dofs(),), minval=-1)

    # Closure exposing only the quantities to be differentiated.
    def rnea_closure(
        base_pos,
        base_quat,
        joint_pos,
        base_vel,
        joint_vel,
        base_accel,
        joint_accel,
        transforms,
        f_links,
        gravity,
    ):
        return jaxsim.rbda.rnea(
            model=model,
            base_position=base_pos,
            base_quaternion=base_quat / jnp.linalg.norm(base_quat),
            joint_positions=joint_pos,
            base_linear_velocity=base_vel[0:3],
            base_angular_velocity=base_vel[3:6],
            joint_velocities=joint_vel,
            base_linear_acceleration=base_accel[0:3],
            base_angular_acceleration=base_accel[3:6],
            joint_accelerations=joint_accel,
            joint_transforms=transforms,
            link_forces=f_links,
            standard_gravity=gravity,
        )

    # State (in VelRepr.Inertial representation), accelerations, joint transforms,
    # link forces, and gravity constant, in the order expected by the closure.
    differentiated_args = (
        data.base_position,
        data.base_orientation,
        data.joint_positions,
        data.base_velocity,
        data.joint_velocities,
        base_acc,
        joint_acc,
        data._joint_transforms,
        references.link_forces(model=model),
        jaxsim.math.STANDARD_GRAVITY,
    )

    # Check derivatives against finite differences.
    check_grads(
        f=rnea_closure,
        args=differentiated_args,
        order=AD_ORDER,
        modes=["rev", "fwd"],
        eps=ε,
    )
def test_ad_crba(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """Check the derivatives of jaxsim.rbda.crba against finite differences."""

    model = jaxsim_models_types

    subkey = jax.random.split(prng_key, num=2)[1]
    data, _ = get_random_data_and_references(
        model=model, velocity_representation=VelRepr.Inertial, key=subkey
    )

    # Closure exposing only the joint positions, the quantity to differentiate.
    def crba_closure(joint_pos):
        return jaxsim.rbda.crba(model=model, joint_positions=joint_pos)

    # Check derivatives against finite differences.
    check_grads(
        f=crba_closure,
        args=(data.joint_positions,),
        order=AD_ORDER,
        modes=["rev", "fwd"],
        eps=ε,
    )
def test_ad_fk(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """
    Check the derivatives of jaxsim.rbda.forward_kinematics_model against
    finite differences.
    """

    model = jaxsim_models_types

    subkey = jax.random.split(prng_key, num=2)[1]
    data, _ = get_random_data_and_references(
        model=model, velocity_representation=VelRepr.Inertial, key=subkey
    )

    # Closure exposing only the quantities to be differentiated.
    def fk_closure(base_pos, base_quat, joint_pos, base_lin_vel, base_ang_vel, joint_vel):
        import jaxlie

        # Finite differences perturb the quaternion, hence the normalization.
        unit_quat = base_quat / jnp.linalg.norm(base_quat)

        base_transform = jaxlie.SE3.from_rotation_and_translation(
            rotation=jaxlie.SO3(wxyz=unit_quat),
            translation=base_pos,
        ).as_matrix()

        transforms = model.kin_dyn_parameters.joint_transforms(
            joint_positions=joint_pos, base_transform=base_transform
        )

        return jaxsim.rbda.forward_kinematics_model(
            model=model,
            base_position=base_pos,
            base_quaternion=unit_quat,
            joint_positions=joint_pos,
            base_linear_velocity_inertial=base_lin_vel,
            base_angular_velocity_inertial=base_ang_vel,
            joint_velocities=joint_vel,
            joint_transforms=transforms,
        )

    # State in VelRepr.Inertial representation, in the closure's argument order.
    differentiated_args = (
        data.base_position,
        data.base_orientation,
        data.joint_positions,
        data._base_linear_velocity,
        data._base_angular_velocity,
        data.joint_velocities,
    )

    # Check derivatives against finite differences.
    check_grads(
        f=fk_closure,
        args=differentiated_args,
        order=AD_ORDER,
        modes=["rev", "fwd"],
        eps=ε,
    )
def test_ad_jacobian(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """Check the derivatives of jaxsim.rbda.jacobian against finite differences."""

    model = jaxsim_models_types

    subkey = jax.random.split(prng_key, num=2)[1]
    data, _ = get_random_data_and_references(
        model=model, velocity_representation=VelRepr.Inertial, key=subkey
    )

    # The jacobian of the last link is differentiated, since that link is
    # likely among those farther from the base.
    last_link_index = jnp.arange(model.number_of_links())[-1]

    # Closure exposing only the joint positions, the quantity to differentiate.
    def jacobian_closure(joint_pos):
        return jaxsim.rbda.jacobian(
            model=model, joint_positions=joint_pos, link_index=last_link_index
        )

    # Check derivatives against finite differences.
    check_grads(
        f=jacobian_closure,
        args=(data.joint_positions,),
        order=AD_ORDER,
        modes=["rev", "fwd"],
        eps=ε,
    )
def test_ad_soft_contacts(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """
    Check the derivatives of SoftContacts.compute_contact_force against
    finite differences.
    """

    model = jaxsim_models_types

    # Switch the model to the soft contact model.
    with model.editable(validate=False) as model:
        model.contact_model = jaxsim.rbda.contacts.SoftContacts.build()

    # Random position, velocity, and tangential deformation.
    _, pos_key, vel_key, def_key = jax.random.split(prng_key, num=4)
    p = jax.random.uniform(pos_key, shape=(3,), minval=-1)
    v = jax.random.uniform(vel_key, shape=(3,), minval=-1)
    m = jax.random.uniform(def_key, shape=(3,), minval=-1)

    # Get the soft contacts parameters.
    parameters = js.contact.estimate_good_contact_parameters(model=model)

    # Closure exposing only the quantities to be differentiated.
    def close_over_inputs_and_parameters(
        p: jtp.VectorLike,
        v: jtp.VectorLike,
        m: jtp.VectorLike,
        params: SoftContactsParams,
    ) -> tuple[jtp.Vector, jtp.Vector]:

        contact_force, deformation_rate = SoftContacts.compute_contact_force(
            position=p,
            velocity=v,
            tangential_deformation=m,
            parameters=params,
            terrain=model.terrain,
        )

        return contact_force, deformation_rate

    # On GPU, the tolerance needs to be increased.
    runs_on_gpu = any(device.platform == "gpu" for device in p.devices())
    relative_tolerance = 0.02 if runs_on_gpu else None

    # Check derivatives against finite differences.
    check_grads(
        f=close_over_inputs_and_parameters,
        args=(p, v, m, parameters),
        order=AD_ORDER,
        modes=["rev", "fwd"],
        eps=ε,
        rtol=relative_tolerance,
    )
def test_ad_integration(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """Check the derivatives of js.model.step against finite differences."""

    model = jaxsim_models_types

    subkey = jax.random.split(prng_key, num=2)[1]
    data, references = get_random_data_and_references(
        model=model, velocity_representation=VelRepr.Inertial, key=subkey
    )

    # Function exposing only the quantities to be differentiated.
    def step(
        W_p_B: jtp.Vector,
        W_Q_B: jtp.Vector,
        s: jtp.Vector,
        W_v_WB: jtp.Vector,
        ṡ: jtp.Vector,
        τ: jtp.Vector,
        W_f_L: jtp.Matrix,
    ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:

        # When JAX tests against finite differences, the injected ε will make the
        # quaternion non-unitary, which will cause the AD check to fail.
        unit_quaternion = W_Q_B / jnp.linalg.norm(W_Q_B)

        # Initial state built from the differentiated quantities.
        initial_data = data.replace(
            model=model,
            base_position=W_p_B,
            base_quaternion=unit_quaternion,
            joint_positions=s,
            base_linear_velocity=W_v_WB[0:3],
            base_angular_velocity=W_v_WB[3:6],
            joint_velocities=ṡ,
        )

        # Advance the simulation by one step.
        final_data = js.model.step(
            model=model,
            data=initial_data,
            joint_force_references=τ,
            link_forces=W_f_L,
        )

        return (
            final_data.base_position,
            final_data.base_orientation,
            final_data.joint_positions,
            final_data.base_velocity,
            final_data.joint_velocities,
        )

    # State (in VelRepr.Inertial representation) followed by the inputs, in the
    # order expected by the step function.
    differentiated_args = (
        data.base_position,
        data.base_orientation,
        data.joint_positions,
        data.base_velocity,
        data.joint_velocities,
        references.joint_force_references(model=model),
        references.link_forces(model=model),
    )

    # Check derivatives against finite differences.
    check_grads(
        f=step,
        args=differentiated_args,
        order=AD_ORDER,
        modes=["fwd", "rev"],
        eps=ε,
    )
def test_ad_safe_norm(
    prng_key: jax.Array,
):
    """
    Check that jaxsim.math.safe_norm supports batching, matches
    np.linalg.norm, and is differentiable, also at zero.
    """

    subkey = jax.random.split(prng_key, num=2)[1]
    sample = jax.random.uniform(subkey, shape=(4,), minval=-5, maxval=5)

    # Two identical rows, to exercise the batched code path.
    batch = jnp.stack([sample, sample])

    # The norm is reduced over the last axis, leaving one value per row.
    assert jaxsim.math.safe_norm(batch, axis=-1).shape == (2,)

    # The values must agree with the NumPy reference.
    assert_allclose(
        jaxsim.math.safe_norm(batch, axis=-1), np.linalg.norm(batch, axis=-1)
    )

    # Function exposing only the array to be differentiated.
    def safe_norm(array: jtp.Array) -> jtp.Array:
        return jaxsim.math.safe_norm(array, axis=-1)

    # Check derivatives against finite differences, first on the random batch
    # and then on an all-zero array of the same shape.
    for test_input in (batch, jnp.zeros_like(batch)):
        check_grads(
            f=safe_norm,
            args=(test_input,),
            order=AD_ORDER,
            modes=["rev", "fwd"],
            eps=ε,
        )
def test_ad_hw_parameters(
    jaxsim_model_garpez: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """
    Check that the model can be differentiated with respect to the hardware
    scaling factors of its links.
    """

    model = jaxsim_model_garpez
    data = js.data.JaxSimModelData.build(model=model)

    # Range from which the random scaling factors are sampled.
    lower, upper = 0.5, 10.0
    n_links = model.number_of_links()

    # Random scaling factors for link dimensions and densities.
    _, dims_key, density_key = jax.random.split(prng_key, num=3)
    scaling_factors = js.kin_dyn_parameters.ScalingFactors(
        dims=jax.random.uniform(
            dims_key, shape=(n_links, 3), minval=lower, maxval=upper
        ),
        density=jax.random.uniform(
            density_key, shape=(n_links,), minval=lower, maxval=upper
        ),
    )

    link_idx = js.link.name_to_idx(model, link_name="link4")

    # Function mapping the scaling factors to the position of link4 and to the
    # free-floating mass matrix of the updated model.
    def update_hw_params_and_compute_fk_and_mass(
        scaling_factors: js.kin_dyn_parameters.ScalingFactors,
    ):
        scaled_model = js.model.update_hw_parameters(
            model=model, scaling_factors=scaling_factors
        )

        link_transforms = js.model.forward_kinematics(model=scaled_model, data=data)
        link4_position = link_transforms[link_idx][:3, 3]

        mass_matrix = js.model.free_floating_mass_matrix(scaled_model, data)

        return link4_position, mass_matrix

    # Check derivatives against finite differences.
    check_grads(
        f=update_hw_params_and_compute_fk_and_mass,
        args=(scaling_factors,),
        order=AD_ORDER,
        modes=["fwd", "rev"],
        eps=ε,
    )
================================================
FILE: tests/test_benchmark.py
================================================
from collections.abc import Callable
import jax
import jax.numpy as jnp
import pytest
import jaxsim
import jaxsim.api as js
from jaxsim.api.kin_dyn_parameters import ScalingFactors
def vectorize_data(model: js.model.JaxSimModel, batch_size: int):
    """
    Build a batch of random model data.

    Args:
        model: The model for which the data is generated.
        batch_size: The number of random data samples in the batch.

    Returns:
        The output of `js.data.random_model_data` vmapped over `batch_size`
        PRNG keys split from a fixed seed.
    """

    def sample_data(single_key):
        return js.data.random_model_data(
            model=model,
            key=single_key,
        )

    # A fixed seed keeps the generated batch identical across runs.
    root_key = jax.random.PRNGKey(seed=0)
    batch_keys = jax.random.split(root_key, num=batch_size)

    return jax.vmap(sample_data)(batch_keys)
def benchmark_test_function(
    func: Callable, model: js.model.JaxSimModel, benchmark, batch_size
):
    """
    Benchmark `func` vmapped over a batch of random model data.

    Args:
        func: The function to benchmark, called as `func(model, data)`.
        model: The model passed unbatched to `func`.
        benchmark: The benchmark callable, invoked as `benchmark(fn, *args)`.
        batch_size: The number of random data samples in the batch.
    """

    data = vectorize_data(model=model, batch_size=batch_size)

    # The model is shared by the whole batch, only the data is mapped.
    batched_func = jax.vmap(func, in_axes=(None, 0))

    def run_and_wait(model, data):
        # Block on the *result* of the call. Calling jax.block_until_ready on the
        # vmapped callable itself is a no-op, and the benchmark would then measure
        # only the asynchronous dispatch.
        return jax.block_until_ready(batched_func(model, data))

    # Warm-up call to avoid including compilation time
    run_and_wait(model, data)

    # Benchmark the function call, including the time needed to compute the result
    benchmark(run_and_wait, model, data)
@pytest.mark.benchmark
def test_forward_dynamics_aba(
    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size
):
    """Benchmark `js.model.forward_dynamics_aba` over a batch of random data."""

    benchmark_test_function(
        func=js.model.forward_dynamics_aba,
        model=jaxsim_model_ergocub_reduced,
        benchmark=benchmark,
        batch_size=batch_size,
    )
@pytest.mark.benchmark
def test_free_floating_bias_forces(
    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size
):
    """Benchmark `js.model.free_floating_bias_forces` over a batch of random data."""

    benchmark_test_function(
        func=js.model.free_floating_bias_forces,
        model=jaxsim_model_ergocub_reduced,
        benchmark=benchmark,
        batch_size=batch_size,
    )
@pytest.mark.benchmark
def test_forward_kinematics(
    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size
):
    """Benchmark `js.model.forward_kinematics` over a batch of random data."""

    benchmark_test_function(
        func=js.model.forward_kinematics,
        model=jaxsim_model_ergocub_reduced,
        benchmark=benchmark,
        batch_size=batch_size,
    )
@pytest.mark.benchmark
def test_free_floating_mass_matrix(
    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size
):
    """Benchmark `js.model.free_floating_mass_matrix` over a batch of random data."""

    benchmark_test_function(
        func=js.model.free_floating_mass_matrix,
        model=jaxsim_model_ergocub_reduced,
        benchmark=benchmark,
        batch_size=batch_size,
    )
@pytest.mark.benchmark
def test_free_floating_jacobian(
    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size
):
    """
    Benchmark `js.model.generalized_free_floating_jacobian` over a batch of
    random data.
    """

    benchmark_test_function(
        func=js.model.generalized_free_floating_jacobian,
        model=jaxsim_model_ergocub_reduced,
        benchmark=benchmark,
        batch_size=batch_size,
    )
@pytest.mark.benchmark
def test_free_floating_jacobian_derivative(
    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size
):
    """
    Benchmark `js.model.generalized_free_floating_jacobian_derivative` over a
    batch of random data.
    """

    benchmark_test_function(
        func=js.model.generalized_free_floating_jacobian_derivative,
        model=jaxsim_model_ergocub_reduced,
        benchmark=benchmark,
        batch_size=batch_size,
    )
@pytest.mark.benchmark
def test_soft_contact_model(
    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size
):
    """Benchmark `js.ode.system_dynamics` with the SoftContacts contact model."""

    # Select the contact model first, then estimate parameters for that model.
    with jaxsim_model_ergocub_reduced.editable(validate=False) as model:
        model.contact_model = jaxsim.rbda.contacts.SoftContacts()
        model.contact_params = js.contact.estimate_good_contact_parameters(model=model)

    benchmark_test_function(
        func=js.ode.system_dynamics,
        model=model,
        benchmark=benchmark,
        batch_size=batch_size,
    )
@pytest.mark.benchmark
def test_rigid_contact_model(
    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size
):
    """Benchmark `js.ode.system_dynamics` with the RigidContacts contact model."""

    # Select the contact model first, then estimate parameters for that model.
    with jaxsim_model_ergocub_reduced.editable(validate=False) as model:
        model.contact_model = jaxsim.rbda.contacts.RigidContacts()
        model.contact_params = js.contact.estimate_good_contact_parameters(model=model)

    benchmark_test_function(
        func=js.ode.system_dynamics,
        model=model,
        benchmark=benchmark,
        batch_size=batch_size,
    )
@pytest.mark.benchmark
def test_relaxed_rigid_contact_model(
    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size
):
    """
    Benchmark `js.ode.system_dynamics` with the RelaxedRigidContacts contact model.
    """

    # Select the contact model first, then estimate parameters for that model.
    with jaxsim_model_ergocub_reduced.editable(validate=False) as model:
        model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts()
        model.contact_params = js.contact.estimate_good_contact_parameters(model=model)

    benchmark_test_function(
        func=js.ode.system_dynamics,
        model=model,
        benchmark=benchmark,
        batch_size=batch_size,
    )
@pytest.mark.benchmark
def test_simulation_step(
    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size
):
    """Benchmark `js.model.step` with the RelaxedRigidContacts contact model."""

    # Select the contact model first, then estimate parameters for that model.
    with jaxsim_model_ergocub_reduced.editable(validate=False) as model:
        model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts()
        model.contact_params = js.contact.estimate_good_contact_parameters(model=model)

    benchmark_test_function(
        func=js.model.step,
        model=model,
        benchmark=benchmark,
        batch_size=batch_size,
    )
@pytest.mark.benchmark
def test_update_hw_parameters(
    jaxsim_model_garpez: js.model.JaxSimModel, benchmark, batch_size
):
    """Benchmark hardware parameter scaling/update operation (vmapped)."""

    model = jaxsim_model_garpez
    n_links = model.number_of_links()

    # Create a function that generates random scaling factors and updates the model
    def update_with_random_scaling(key: jax.Array) -> js.model.JaxSimModel:
        # Generate scaling factors in a reasonable range [0.8, 1.2]
        dims_scale = jax.random.uniform(key, shape=(n_links, 3), minval=0.8, maxval=1.2)
        density_scale = jax.random.uniform(
            jax.random.fold_in(key, 1), shape=(n_links,), minval=0.8, maxval=1.2
        )
        scaling_factors = ScalingFactors(dims=dims_scale, density=density_scale)
        return js.model.update_hw_parameters(model, scaling_factors)

    # Generate batch of random keys
    key = jax.random.PRNGKey(seed=42)
    keys = jax.random.split(key, num=batch_size)

    batched_update = jax.vmap(update_with_random_scaling)

    def run_and_wait(keys: jax.Array) -> js.model.JaxSimModel:
        # Block on the *result* of the call. Calling jax.block_until_ready on the
        # vmapped callable itself is a no-op, and the benchmark would then measure
        # only the asynchronous dispatch.
        return jax.block_until_ready(batched_update(keys))

    # Warm-up call to avoid including compilation time
    run_and_wait(keys)

    # Benchmark the vmapped update operation
    benchmark(run_and_wait, keys)
@pytest.mark.benchmark
def test_export_updated_model(
    jaxsim_model_garpez: js.model.JaxSimModel, benchmark, batch_size
):
    """Benchmark the export of models whose hardware parameters were updated."""

    model = jaxsim_model_garpez
    n_links = model.number_of_links()

    # One uniform scaling value per model, drawn from a fixed seed.
    scaling_values = jax.random.uniform(
        jax.random.PRNGKey(seed=42), shape=(batch_size,), minval=0.9, maxval=1.2
    )

    # Build all the scaled models up front, so that only the export is timed.
    updated_models = [
        js.model.update_hw_parameters(
            model,
            ScalingFactors(
                dims=jnp.ones((n_links, 3)) * float(scale_value),
                density=jnp.ones(n_links),
            ),
        )
        for scale_value in scaling_values
    ]

    # The export returns a string, therefore the models are exported one after
    # the other in plain Python instead of through a JIT-compiled function.
    def export_all():
        exported = []
        for updated_model in updated_models:
            exported.append(updated_model.export_updated_model())
        return exported

    benchmark(export_all)
================================================
FILE: tests/test_exceptions.py
================================================
import io
from contextlib import redirect_stdout
import chex
import jax
import jax.numpy as jnp
import pytest
from jax.errors import JaxRuntimeError
from jaxsim import exceptions
def test_exceptions_in_jit_functions():
# Check that the jaxsim exception helpers can be used inside a jit-compiled
# function: regular inputs pass through unchanged, repeated calls do not retrace
# the function, and a JaxRuntimeError carrying the formatted message is raised
# when one of the conditions is met.
msg_during_jit = "Compiling jit_compiled_function"
@jax.jit
@chex.assert_max_traces(n=1)
def jit_compiled_function(data: jax.Array) -> jax.Array:
# This message is printed only while the function is traced for JIT
# compilation, not when the already compiled function is executed.
print(msg_during_jit)
# Condition that will trigger the ValueError.
failed_if_42_plus = jnp.allclose(data, 42)
# Raise a ValueError if the condition is met.
# The fmt string is built from kwargs.
exceptions.raise_value_error_if(
condition=failed_if_42_plus,
msg="Raising ValueError since data={num}",
num=data,
)
# Condition that will trigger the RuntimeError.
failed_if_42_minus = jnp.allclose(data, -42)
# Raise a RuntimeError if the condition is met.
# The fmt string is built from args.
exceptions.raise_runtime_error_if(
failed_if_42_minus,
"Raising RuntimeError since data={}",
data,
)
return data
# In the first call, the function will be compiled and print the message.
# The standard output is captured so that the message can be inspected.
with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf):
data = 40
out = jit_compiled_function(data=data)
stdout = buf.getvalue()
assert out == data
assert msg_during_jit in stdout
# In the second call, the function won't be compiled and won't print the message.
with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf):
data = 41
out = jit_compiled_function(data=data)
stdout = buf.getvalue()
assert out == data
assert msg_during_jit not in stdout
# Let's trigger a ValueError exception by passing 42.
# The expected error type is JaxRuntimeError, and its message is expected to
# contain the name of the original exception followed by the formatted text.
data = 42
with pytest.raises(
JaxRuntimeError,
match=f"ValueError: Raising ValueError since data={data}",
):
_ = jit_compiled_function(data=data)
# Let's trigger a RuntimeError exception by passing -42.
data = -42
with pytest.raises(
JaxRuntimeError,
match=f"RuntimeError: Raising RuntimeError since data={data}",
):
_ = jit_compiled_function(data=data)
================================================
FILE: tests/test_meshes.py
================================================
import trimesh
from jaxsim.parsers.rod import meshes
def test_mesh_wrapping_vertex_extraction():
    """
    Check that the vertex extraction returns one point per mesh vertex.

    The check is performed on:

    1. A simple box.
    2. A sphere.
    """

    # Case 1: a box with origin at (0,0,0) and extents (3,3,3),
    # i.e. its points span from -1.5 to 1.5 along each axis.
    box = trimesh.creation.box(
        extents=[3.0, 3.0, 3.0],
    )
    box_points = meshes.extract_points_vertices(mesh=box)
    assert len(box_points) == len(box.vertices)

    # Case 2: a sphere centered at the origin with a radius of 1.0.
    sphere = trimesh.creation.icosphere(subdivisions=4, radius=1.0)
    sphere_points = meshes.extract_points_vertices(mesh=sphere)
    assert len(sphere_points) == len(sphere.vertices)
def test_mesh_wrapping_aap():
    """
    Check the AAP wrapping method on different meshes.

    1. A simple box
        1.1: Keep only the points above x=0.0 (`lower=0.0`)
        1.2: Keep only the points below y=0.0 (`upper=0.0`)
    2. A sphere, keeping only the points with y >= 0.0
    """

    # Box with origin at (0,0,0) and extents (3,3,3),
    # i.e. its points span from -1.5 to 1.5 along each axis.
    box = trimesh.creation.box(extents=[3.0, 3.0, 3.0])
    half_box_count = len(box.vertices) // 2

    # Case 1.1: discard the points that are not above x=0.0.
    # The expected result is that the number of points is halved.
    above_x = meshes.extract_points_aap(mesh=box, axis="x", lower=0.0)
    assert len(above_x) == half_box_count
    assert all(above_x[:, 0] > 0.0)

    # Case 1.2: discard the points that are not below y=0.0.
    # The expected result is that the number of points is halved.
    below_y = meshes.extract_points_aap(mesh=box, axis="y", upper=0.0)
    assert len(below_y) == half_box_count
    assert all(below_y[:, 1] < 0.0)

    # Case 2: a sphere centered at the origin with a radius of 1.0.
    sphere = trimesh.creation.icosphere(subdivisions=4, radius=1.0)

    # Keep the points with y >= 0.0. Only a strict reduction of the number of
    # points is asserted here, not an exact halving.
    upper_half = meshes.extract_points_aap(mesh=sphere, axis="y", lower=0.0)
    assert all(upper_half[:, 1] >= 0.0)
    assert len(upper_half) < len(sphere.vertices)
def test_mesh_wrapping_points_over_axis():
    """
    Check the selection of points over an axis on different meshes.

    1. A simple box
        1.1: Select 4 points from the lower end of the x-axis
        1.2: Select 4 points from the higher end of the y-axis
    2. A sphere, selecting half of its vertices from the higher end of the z-axis
    """

    # Box with origin at (0,0,0) and extents (3,3,3),
    # i.e. its points span from -1.5 to 1.5 along each axis.
    box = trimesh.creation.box(extents=[3.0, 3.0, 3.0])

    # Case 1.1: select 4 points from the lower end of the x-axis.
    lowest_x = meshes.extract_points_select_points_over_axis(
        mesh=box, axis="x", direction="lower", n=4
    )
    assert len(lowest_x) == 4
    assert all(lowest_x[:, 0] < 0.0)

    # Case 1.2: select 4 points from the higher end of the y-axis.
    highest_y = meshes.extract_points_select_points_over_axis(
        mesh=box, axis="y", direction="higher", n=4
    )
    assert len(highest_y) == 4
    assert all(highest_y[:, 1] > 0.0)

    # Case 2: a sphere centered at the origin with a radius of 1.0.
    sphere = trimesh.creation.icosphere(subdivisions=4, radius=1.0)
    sphere_n_vertices = len(sphere.vertices)

    # Select half of the vertices from the higher end of the z-axis.
    highest_z = meshes.extract_points_select_points_over_axis(
        mesh=sphere, axis="z", direction="higher", n=sphere_n_vertices // 2
    )
    assert len(highest_z) == sphere_n_vertices // 2
    assert all(highest_z[:, 2] >= 0.0)
================================================
FILE: tests/test_pytree.py
================================================
import io
import pathlib
from contextlib import redirect_stdout
import chex
import jax
import jax.numpy as jnp
import pytest
import jaxsim.api as js
def test_call_jit_compiled_function_passing_different_objects(
ergocub_model_description_path: pathlib.Path, jaxsim_model_box
):
# Check that JIT-compiled functions are not recompiled when they are called
# with distinct but equal JaxSimModel objects, and that they are retraced when
# they are called with a different model.
# Create a first model from the URDF.
ergocub_model1 = js.model.JaxSimModel.build_from_model_description(
model_description=ergocub_model_description_path
)
# Create a second model from the URDF.
ergocub_model2 = js.model.JaxSimModel.build_from_model_description(
model_description=ergocub_model_description_path
)
box_model = jaxsim_model_box
# The objects should be different, but the comparison should return True.
# The hashes must match as well.
assert id(ergocub_model1) != id(ergocub_model2)
assert ergocub_model1 == ergocub_model2
assert hash(ergocub_model1) == hash(ergocub_model2)
# If this function has never been compiled by any other test, JAX will
# jit-compile it here.
_ = js.contact.estimate_good_contact_parameters(model=ergocub_model1)
# Now JAX should not compile it again.
# The standard output is captured to look for a compilation message.
with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf):
# Beyond running without any JIT recompilations, the following function
# should work on different JaxSimModel objects without raising any errors
# related to the comparison of Static fields.
_ = js.contact.estimate_good_contact_parameters(model=ergocub_model2)
stdout = buf.getvalue()
assert (
f"Compiling {js.contact.estimate_good_contact_parameters.__name__}"
not in stdout
)
# Define a new JIT-compiled function and check that is not recompiled for
# different model objects having the same pytree structure.
# The chex decorator raises an AssertionError if the function is traced
# more than once.
@jax.jit
@chex.assert_max_traces(n=1)
def my_jit_function(model: js.model.JaxSimModel, data: js.data.JaxSimModelData):
# Return random elements from model and data, just to have something returned.
return (
jnp.sum(model.kin_dyn_parameters.link_parameters.mass),
data.base_position,
)
data1 = js.data.JaxSimModelData.build(model=ergocub_model1)
# First call: the function is traced here.
_ = my_jit_function(model=ergocub_model1, data=data1)
# This should not retrace the function, as ergocub_model2 has the same
# pytree structure as ergocub_model1.
_ = my_jit_function(model=ergocub_model2, data=data1)
# Calling the function with a different model object will retrace it, as
# expected. Therefore, an AssertionError should be raised.
with pytest.raises(
AssertionError, match="Function 'my_jit_function' is traced > 1 times!"
):
data3 = js.data.JaxSimModelData.build(model=box_model)
_ = my_jit_function(model=box_model, data=data3)
================================================
FILE: tests/test_simulations.py
================================================
import jax
import jax.numpy as jnp
import numpy as np
import pytest
import jaxsim.api as js
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim import VelRepr
from jaxsim.api.kin_dyn_parameters import ConstraintType
from .utils import assert_allclose
def test_box_with_external_forces(
jaxsim_model_box: js.model.JaxSimModel,
velocity_representation: VelRepr,
):
"""
Simulate a box falling due to gravity.
We apply to its CoM a 6D force that balances exactly the gravitational force.
The box should not fall.
"""
model = jaxsim_model_box
# Build the data of the model.
data0 = js.data.JaxSimModelData.build(
model=model,
base_position=jnp.array([0.0, 0.0, 0.5]),
velocity_representation=velocity_representation,
)
# Compute the force due to gravity at the CoM.
# Only the third component of the 6D force is non-zero.
mg = -model.gravity * js.model.total_mass(model=model)
G_f = jnp.array([0.0, 0.0, mg, 0, 0, 0])
# Compute the position of the CoM expressed in the coordinates of the link frame L.
L_p_CoM = js.link.com_position(
model=model, data=data0, link_index=0, in_link_frame=True
)
# Compute the transform of 6D forces from the CoM to the link frame.
# The force transform is obtained as the transpose of the inverse
# velocity transform.
L_H_G = jaxsim.math.Transform.from_quaternion_and_translation(translation=L_p_CoM)
G_Xv_L = jaxsim.math.Adjoint.from_transform(transform=L_H_G, inverse=True)
L_Xf_G = G_Xv_L.T
L_f = L_Xf_G @ G_f
# Initialize a references object that simplifies handling external forces.
references = js.references.JaxSimModelReferences.build(
model=model,
data=data0,
velocity_representation=velocity_representation,
)
# Apply the link force to the base link.
# The force is expressed in the link frame, hence the Body representation.
with references.switch_velocity_representation(VelRepr.Body):
references = references.apply_link_forces(
forces=jnp.atleast_2d(L_f),
link_names=model.link_names()[0:1],
model=model,
data=data0,
additive=False,
)
# Initialize the simulation horizon.
# The time stamps are expressed in nanoseconds.
tf = 0.5
T_ns = jnp.arange(start=0, stop=tf * 1e9, step=model.time_step * 1e9, dtype=int)
# Copy the initial data...
data = data0.copy()
# ... and step the simulation.
for _ in T_ns:
data = js.model.step(
model=model,
data=data,
link_forces=references.link_forces(model, data),
)
# Check that the box didn't move.
assert_allclose(data.base_position, data0.base_position)
assert_allclose(data.base_orientation, data0.base_orientation)
def test_box_with_zero_gravity(
jaxsim_model_box: js.model.JaxSimModel,
velocity_representation: VelRepr,
prng_key: jnp.ndarray,
):
# Simulate a box with zero gravity and a far-away terrain while a constant
# force is applied, then compare the final base position with the closed-form
# displacement of a uniformly accelerated motion.
model = jaxsim_model_box
# Move the terrain (almost) infinitely far away from the box.
with model.editable(validate=False) as model:
model.terrain = jaxsim.terrain.FlatTerrain.build(height=-1e9)
model.gravity = 0.0
# Split the PRNG key.
_, subkey = jax.random.split(prng_key, num=2)
# Build the data of the model.
# The initial base position is random.
data0 = js.data.JaxSimModelData.build(
model=model,
base_position=jax.random.uniform(subkey, shape=(3,)),
velocity_representation=velocity_representation,
)
# Initialize a references object that simplifies handling external forces.
references = js.references.JaxSimModelReferences.build(
model=model,
data=data0,
velocity_representation=velocity_representation,
)
# Apply the link forces to all the links of the model.
with references.switch_velocity_representation(jaxsim.VelRepr.Mixed):
# Generate a random linear force.
# We enforce them to be the same for all velocity representations so that
# we can compare their outcomes.
# A fixed key is used, and the last three components of each 6D force
# are set to zero.
LW_f = 10.0 * (
jax.random.uniform(jax.random.key(0), shape=(model.number_of_links(), 6))
.at[:, 3:]
.set(jnp.zeros(3))
)
# Note that the context manager does not switch back the newly created
# `references` (that is not the yielded object) to the original representation.
# In the simulation loop below, we need to make sure that we switch both `data`
# and `references` to the same representation before extracting the information
# passed to the step function.
references = references.apply_link_forces(
forces=jnp.atleast_2d(LW_f),
link_names=model.link_names(),
model=model,
data=data0,
additive=False,
)
# Initialize the simulation horizon, with time stamps in nanoseconds.
tf = 0.01
T = jnp.arange(start=0, stop=tf * 1e9, step=model.time_step * 1e9, dtype=int)
# Copy the initial data...
data = data0.copy()
# ... and step the simulation.
for _ in T:
with (
data.switch_velocity_representation(velocity_representation),
references.switch_velocity_representation(velocity_representation),
):
data = js.model.step(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
)
# Check that the box moved as expected.
# The expected displacement is 0.5 * (f / m) * tf**2.
# NOTE(review): the squeeze() assumes that the model has a single link — confirm.
assert_allclose(
data.base_position,
data0.base_position
+ 0.5 * LW_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2,
atol=1e-3,
)
def run_simulation(
    model: js.model.JaxSimModel,
    data_t0: js.data.JaxSimModelData,
    tf: jtp.FloatLike,
) -> js.data.JaxSimModelData:
    """
    Step the model from an initial state over the given time horizon.

    Args:
        model: The model to simulate. Its `time_step` defines the step size.
        data_t0: The data at the initial time. A copy of it is advanced.
        tf: The final time of the simulation.

    Returns:
        The data obtained after the last step.
    """

    # Build the integration horizon in nanoseconds.
    horizon_end_ns = int(tf * 1e9)
    step_ns = int(model.time_step * 1e9)
    T_ns = jnp.arange(start=0.0, stop=horizon_end_ns, step=step_ns).astype(int)

    # Start from a copy of the initial data.
    data = data_t0.copy()

    # Perform one step for each element of the horizon.
    for _ in T_ns:
        data = js.model.step(model=model, data=data)

    return data
def test_simulation_with_soft_contacts(
jaxsim_model_box: js.model.JaxSimModel, integrator
):
# Drop a box on the terrain using the soft contact model and check that it
# comes to rest at the expected height, without any horizontal displacement.
# NOTE(review): the `integrator` fixture is requested but never assigned to
# the model in this test, whereas test_simulation_with_relaxed_rigid_contacts
# sets `model.integrator = integrator`. Confirm whether the assignment is
# missing here.
model = jaxsim_model_box
# Define the maximum penetration of each collidable point at steady state.
max_penetration = 0.001
with model.editable(validate=False) as model:
model.contact_model = jaxsim.rbda.contacts.SoftContacts.build()
# Compute contact parameters compatible with the desired penetration.
model.contact_params = js.contact.estimate_good_contact_parameters(
model=model,
number_of_active_collidable_points_steady_state=4,
static_friction_coefficient=1.0,
damping_ratio=1.0,
max_penetration=max_penetration,
)
# Enable a subset of the collidable points.
# Only the points with index 0, 1, 2, 3 are kept enabled.
enabled_collidable_points_mask = np.zeros(
len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool
)
enabled_collidable_points_mask[[0, 1, 2, 3]] = True
model.kin_dyn_parameters.contact_parameters.enabled = tuple(
enabled_collidable_points_mask.tolist()
)
assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4
# Check jaxsim_model_box@conftest.py.
box_height = 0.1
# Build the data of the model.
# The box starts above the terrain, at twice its height.
data_t0 = js.data.JaxSimModelData.build(
model=model,
base_position=jnp.array([0.0, 0.0, box_height * 2]),
velocity_representation=VelRepr.Inertial,
)
# ===========================================
# Run the simulation and test the final state
# ===========================================
data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0)
# The horizontal position must not change, and the final height plus the
# steady-state penetration must match half of the box height.
assert_allclose(data_tf.base_position[0:2], data_t0.base_position[0:2])
assert_allclose(data_tf.base_position[2] + max_penetration, box_height / 2)
def test_simulation_with_rigid_contacts(
jaxsim_model_box: js.model.JaxSimModel, integrator
):
# Drop a box on the terrain using the rigid contact model and check that it
# comes to rest at half of its height, without any horizontal displacement.
# NOTE(review): the `integrator` fixture is requested but never assigned to
# the model in this test, whereas test_simulation_with_relaxed_rigid_contacts
# sets `model.integrator = integrator`. Confirm whether the assignment is
# missing here.
model = jaxsim_model_box
with model.editable(validate=False) as model:
# In order to achieve almost no penetration, we need to use a fairly large
# Baumgarte stabilization term.
model.contact_model = jaxsim.rbda.contacts.RigidContacts.build(
solver_options={"solver_tol": 1e-3}
)
model.contact_params = model.contact_model._parameters_class(K=1e5)
# Enable a subset of the collidable points.
# Only the points with index 0, 1, 2, 3 are kept enabled.
enabled_collidable_points_mask = np.zeros(
len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool
)
enabled_collidable_points_mask[[0, 1, 2, 3]] = True
model.kin_dyn_parameters.contact_parameters.enabled = tuple(
enabled_collidable_points_mask.tolist()
)
assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4
# Initialize the maximum penetration of each collidable point at steady state.
# This model is rigid, so we expect (almost) no penetration.
max_penetration = 0.000
# Check jaxsim_model_box@conftest.py.
box_height = 0.1
# Build the data of the model.
# The box starts above the terrain, at twice its height.
data_t0 = js.data.JaxSimModelData.build(
model=model,
base_position=jnp.array([0.0, 0.0, box_height * 2]),
velocity_representation=VelRepr.Inertial,
)
# ===========================================
# Run the simulation and test the final state
# ===========================================
data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0)
# The horizontal position must not change, and the final height must match
# half of the box height.
assert_allclose(data_tf.base_position[0:2], data_t0.base_position[0:2])
assert_allclose(data_tf.base_position[2] + max_penetration, box_height / 2)
def test_simulation_with_relaxed_rigid_contacts(
jaxsim_model_box: js.model.JaxSimModel, integrator
):
# Drop a box on the terrain using the relaxed rigid contact model and check
# that it comes to rest at half of its height, without any horizontal
# displacement.
model = jaxsim_model_box
with model.editable(validate=False) as model:
model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts.build(
solver_options={"tol": 1e-3},
)
# Use the default parameters of the contact model.
model.contact_params = model.contact_model._parameters_class()
# Enable a subset of the collidable points.
# Only the points with index 0, 1, 2, 3 are kept enabled.
enabled_collidable_points_mask = np.zeros(
len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool
)
enabled_collidable_points_mask[[0, 1, 2, 3]] = True
model.kin_dyn_parameters.contact_parameters.enabled = tuple(
enabled_collidable_points_mask.tolist()
)
# Use the integrator provided by the fixture.
model.integrator = integrator
assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4
# Initialize the maximum penetration of each collidable point at steady state.
# This model is quasi-rigid, so we expect (almost) no penetration.
max_penetration = 0.000
# Check jaxsim_model_box@conftest.py.
box_height = 0.1
# Build the data of the model.
# The box starts above the terrain, at twice its height.
data_t0 = js.data.JaxSimModelData.build(
model=model,
base_position=jnp.array([0.0, 0.0, box_height * 2]),
velocity_representation=VelRepr.Inertial,
)
# ===========================================
# Run the simulation and test the final state
# ===========================================
data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0)
# With this contact model, we need to slightly increase the tolerances.
assert_allclose(data_tf.base_position[0:2], data_t0.base_position[0:2], atol=1e-5)
assert_allclose(
data_tf.base_position[2] + max_penetration, box_height / 2, atol=1e-4
)
def test_joint_limits(
jaxsim_model_single_pendulum: js.model.JaxSimModel,
):
# Check that, at the end of a simulation, the joint position of the pendulum
# is within the joint position limits up to a tolerance.
model = jaxsim_model_single_pendulum
# Override the position limits and the spring-damper parameters used to
# enforce them.
with model.editable(validate=False) as model:
model.kin_dyn_parameters.joint_parameters.position_limits_max = jnp.atleast_1d(
jnp.array(1.5708)
)
model.kin_dyn_parameters.joint_parameters.position_limits_min = jnp.atleast_1d(
jnp.array(-1.5708)
)
model.kin_dyn_parameters.joint_parameters.position_limit_spring = (
jnp.atleast_1d(jnp.array(75.0))
)
model.kin_dyn_parameters.joint_parameters.position_limit_damper = (
jnp.atleast_1d(jnp.array(0.1))
)
position_limits_min, position_limits_max = js.joint.position_limits(model=model)
data = js.data.JaxSimModelData.build(
model=model,
velocity_representation=VelRepr.Inertial,
)
# Offset of 10 degrees, expressed in radians.
theta = 10 * np.pi / 180
# Define a tolerance since the spring-damper model does
# not guarantee that the joint position will be exactly
# below the limit.
tolerance = theta * 0.10
# Test minimum joint position limits.
# The initial joint position is `theta` below the minimum limit.
data_t0 = data.replace(model=model, joint_positions=position_limits_min - theta)
model = model.replace(time_step=0.005, validate=False)
data_tf = run_simulation(model=model, data_t0=data_t0, tf=3.0)
# NOTE(review): only the final state returned by run_simulation is checked,
# not the whole trajectory — confirm this is the intended coverage.
assert (
np.min(np.array(data_tf.joint_positions), axis=0) + tolerance
>= position_limits_min
)
# Test maximum joint position limits.
# NOTE(review): the initial joint position is `theta` below the maximum limit,
# i.e. inside the limits, unlike the minimum case that starts outside of them.
# Confirm that this asymmetry is intentional.
data_t0 = data.replace(model=model, joint_positions=position_limits_max - theta)
model = model.replace(time_step=0.001)
data_tf = run_simulation(model=model, data_t0=data_t0, tf=3.0)
assert (
np.max(np.array(data_tf.joint_positions), axis=0) - tolerance
<= position_limits_max
)
@pytest.mark.parametrize(
    "initial_joint_positions",
    [
        jnp.array([0, 0]),
        np.pi / 180 * jnp.array([5, 0]),
    ],
)
def test_simulation_with_kinematic_constraints_double_pendulum(
    jaxsim_model_double_pendulum: js.model.JaxSimModel,
    initial_joint_positions: jtp.Array,
):
    """
    Weld the extremity frames of the two links of the double pendulum model and
    check that, after the simulation, the two joint positions are equal within
    a tolerance.

    The test is run with zero gravity and for each parametrized initial joint
    configuration.
    """

    # ========
    # Arrange
    # ========

    tf = 1.0  # Final simulation time in seconds.

    model = jaxsim_model_double_pendulum

    frame_1_name = "right_link_extremity_frame"
    frame_2_name = "left_link_extremity_frame"

    frame_1_idx = js.frame.name_to_idx(model=model, frame_name=frame_1_name)
    frame_2_idx = js.frame.name_to_idx(model=model, frame_name=frame_2_name)

    # Define the kinematic constraints.
    constraints = js.kin_dyn_parameters.ConstraintMap()
    constraints = constraints.add_constraint(
        model=model,
        frame_idx_1=frame_1_idx,
        frame_idx_2=frame_2_idx,
        constraint_type=ConstraintType.Weld,
    )

    # Set the constraints in the model and disable gravity.
    with model.editable(validate=False) as model:
        model.kin_dyn_parameters.constraints = constraints
        model.gravity = 0.0

    # Build the initial data for the model.
    data_t0 = js.data.JaxSimModelData.build(
        model=model,
        velocity_representation=VelRepr.Inertial,
        joint_positions=initial_joint_positions,
    )

    # ====
    # Act
    # ====

    # Simulate the model for a given time and time step.
    data_tf = run_simulation(model=model, data_t0=data_t0, tf=tf)

    # =========
    # Assert
    # =========

    # Assert that the chosen frames exist in the model.
    assert frame_1_name in model.frame_names()
    assert frame_2_name in model.frame_names()

    # Assert that the joint positions are now equal.
    actual_delta_s_tf = jnp.abs(data_tf.joint_positions[0] - data_tf.joint_positions[1])
    expected_delta_s_tf = 0.0

    # The helper signature is (actual, desired): pass the measured value first
    # so that the failure report labels the two values correctly.
    assert_allclose(
        actual_delta_s_tf,
        expected_delta_s_tf,
        atol=1e-2,
        err_msg=f"Position difference [deg]: {actual_delta_s_tf * 180 / np.pi}",
    )
def test_simulation_with_kinematic_constraints_cartpole(
    jaxsim_model_cartpole: js.model.JaxSimModel,
):
    """
    Weld the cart frame to the rail frame of the cartpole model and check that,
    after simulating from a displaced cart position, the two frames have the
    same pose within a tolerance.
    """

    # ========
    # Arrange
    # ========

    final_time = 1.0  # Final simulation time in seconds.

    model = jaxsim_model_cartpole

    cart_frame_name = "cart_frame"
    rail_frame_name = "rail_frame"

    cart_frame_idx = js.frame.name_to_idx(model=model, frame_name=cart_frame_name)
    rail_frame_idx = js.frame.name_to_idx(model=model, frame_name=rail_frame_name)

    # Define the kinematic constraint welding the two frames.
    weld_constraints = js.kin_dyn_parameters.ConstraintMap().add_constraint(
        model,
        cart_frame_idx,
        rail_frame_idx,
        ConstraintType.Weld,
    )

    # Displace the cart from the rail zero position.
    displaced_joint_positions = jnp.array([0.05, 0.0])

    # Set the constraints in the model.
    with model.editable(validate=False) as model:
        model.kin_dyn_parameters.constraints = weld_constraints

    # Build the initial data for the model.
    data_t0 = js.data.JaxSimModelData.build(
        model=model,
        velocity_representation=VelRepr.Inertial,
        joint_positions=displaced_joint_positions,
    )

    # ====
    # Act
    # ====

    # Simulate the model up to the final time.
    data_tf = run_simulation(model=model, data_t0=data_t0, tf=final_time)

    # Compute the final transforms of the two frames.
    H_cart = js.frame.transform(model=model, data=data_tf, frame_index=cart_frame_idx)
    H_rail = js.frame.transform(model=model, data=data_tf, frame_index=rail_frame_idx)

    # =========
    # Assert
    # =========

    # The chosen frames must exist in the model.
    assert cart_frame_name in model.frame_names()
    assert rail_frame_name in model.frame_names()

    # The two frames must have the same pose, i.e. their relative transform
    # must be the identity.
    cart_H_rail = jnp.linalg.inv(H_cart) @ H_rail
    assert_allclose(cart_H_rail, jnp.eye(4), atol=1e-3)
def test_simulation_with_kinematic_constraints_4_bar_linkage(
    jaxsim_model_4_bar_linkage: js.model.JaxSimModel,
):
    """
    Test the kinematic weld constraint on the 4-bar linkage model.

    After the simulation, the two welded frames must have the same position
    and the same orientation within the given tolerances.
    """

    # ========
    # Arrange
    # ========

    final_time = 1.0  # Final simulation time in seconds.

    model = jaxsim_model_4_bar_linkage

    first_frame_name = "BC1_frame"
    second_frame_name = "BC2_frame"

    first_frame_idx = js.frame.name_to_idx(model=model, frame_name=first_frame_name)
    second_frame_idx = js.frame.name_to_idx(model=model, frame_name=second_frame_name)

    # Define the kinematic constraint welding the two frames.
    weld_constraints = js.kin_dyn_parameters.ConstraintMap().add_constraint(
        model=model,
        frame_idx_1=first_frame_idx,
        frame_idx_2=second_frame_idx,
        constraint_type=ConstraintType.Weld,
        K_P=1e4,
    )

    # Set the constraints in the model.
    with model.editable(validate=False) as model:
        model.kin_dyn_parameters.constraints = weld_constraints

    # Build the initial data for the model, with the base at a height of 0.10.
    data_t0 = js.data.JaxSimModelData.build(
        model=model,
        velocity_representation=VelRepr.Inertial,
        base_position=jnp.array([0.0, 0.0, 0.10]),
    )

    # ====
    # Act
    # ====

    data_tf = run_simulation(model=model, data_t0=data_t0, tf=final_time)

    # Compute the final transforms of the two frames.
    H_first = js.frame.transform(
        model=model, data=data_tf, frame_index=first_frame_idx
    )
    H_second = js.frame.transform(
        model=model, data=data_tf, frame_index=second_frame_idx
    )

    # =========
    # Assert
    # =========

    assert first_frame_name in model.frame_names()
    assert second_frame_name in model.frame_names()

    # Position check.
    assert_allclose(H_first[:3, 3], H_second[:3, 3], atol=1e-5)

    # Orientation check: the relative rotation must be the identity.
    relative_rotation = H_first[:3, :3].T @ H_second[:3, :3]
    assert_allclose(relative_rotation, jnp.eye(3), atol=1e-3)
================================================
FILE: tests/test_visualizer.py
================================================
import pytest
import rod
from jaxsim.mujoco import ModelToMjcf
from jaxsim.mujoco.loaders import MujocoCamera
@pytest.fixture
def mujoco_camera():
    """Provide the camera shared by the tests, built from a target view."""

    # The angles are given in degrees.
    target_view = dict(
        camera_name="test_camera",
        lookat=(0, 0, 0),
        distance=1,
        azimuth=0,
        elevation=0,
        fovy=45,
        degrees=True,
    )

    return MujocoCamera.build_from_target_view(**target_view)
def test_urdf_loading(jaxsim_model_single_pendulum, mujoco_camera):
    """Convert to MJCF the description from which the pendulum model was built."""

    description = jaxsim_model_single_pendulum.built_from

    # Only check that the conversion completes; the result is discarded.
    _ = ModelToMjcf.convert(model=description, cameras=mujoco_camera)
def test_sdf_loading(jaxsim_model_single_pendulum, mujoco_camera):
    """Convert to MJCF the serialized SDF of the pendulum model."""

    # Load the description with rod and serialize it back to a string.
    loaded_sdf = rod.Sdf.load(sdf=jaxsim_model_single_pendulum.built_from)
    serialized_sdf = loaded_sdf.serialize(pretty=True)

    # Only check that the conversion completes; the result is discarded.
    _ = ModelToMjcf.convert(model=serialized_sdf, cameras=mujoco_camera)
def test_rod_loading(jaxsim_model_single_pendulum, mujoco_camera):
    """Convert to MJCF the first rod model of the pendulum description."""

    loaded_sdf = rod.Sdf.load(sdf=jaxsim_model_single_pendulum.built_from)
    rod_model = loaded_sdf.models()[0]

    # Only check that the conversion completes; the result is discarded.
    _ = ModelToMjcf.convert(model=rod_model, cameras=mujoco_camera)
def test_heightmap(jaxsim_model_single_pendulum, mujoco_camera):
    """Convert the pendulum model to MJCF with the heightmap option enabled."""

    loaded_sdf = rod.Sdf.load(sdf=jaxsim_model_single_pendulum.built_from)
    rod_model = loaded_sdf.models()[0]

    # Only check that the conversion completes; the result is discarded.
    _ = ModelToMjcf.convert(
        model=rod_model,
        cameras=mujoco_camera,
        heightmap=True,
        heightmap_samples_xy=(51, 51),
    )
def test_inclined_plane(jaxsim_model_single_pendulum, mujoco_camera):
    """Convert the pendulum model to MJCF with a custom plane normal."""

    loaded_sdf = rod.Sdf.load(sdf=jaxsim_model_single_pendulum.built_from)
    rod_model = loaded_sdf.models()[0]

    # Only check that the conversion completes; the result is discarded.
    _ = ModelToMjcf.convert(
        model=rod_model,
        cameras=mujoco_camera,
        plane_normal=(0.3, 0.3, 0.3),
    )
================================================
FILE: tests/utils.py
================================================
from __future__ import annotations
import dataclasses
import pathlib
import idyntree.bindings as idt
import numpy as np
import numpy.typing as npt
import jaxsim.api as js
from jaxsim import VelRepr
def assert_allclose(actual, desired, rtol=1e-7, atol=1e-9, err_msg=""):
    """
    Assert allclose with custom default tolerances.

    Both inputs are converted to float arrays and `0.0` is added to them, so
    that negative zeros become positive zeros before the comparison.

    Args:
        actual: The obtained values.
        desired: The expected values.
        rtol: The relative tolerance.
        atol: The absolute tolerance.
        err_msg: The message reported in case of failure.

    Raises:
        AssertionError: If the values differ more than the tolerances allow.
    """

    def _as_normalized_floats(values):
        # Adding 0.0 maps -0.0 to 0.0 and leaves the other values unchanged.
        return np.asarray(values, dtype=float) + 0.0

    np.testing.assert_allclose(
        _as_normalized_floats(actual),
        _as_normalized_floats(desired),
        rtol=rtol,
        atol=atol,
        err_msg=err_msg,
    )
def build_kindyncomputations_from_jaxsim_model(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
considered_joints: list[str] | None = None,
removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None,
) -> KinDynComputations:
"""
Build a `KinDynComputations` from `JaxSimModel` and `JaxSimModelData`.
Args:
model: The `JaxSimModel` from which to build the `KinDynComputations`.
data: The `JaxSimModelData` from which to build the `KinDynComputations`.
considered_joints:
The list of joint names to consider in the `KinDynComputations`.
removed_joint_positions:
A dictionary defining the positions of the removed joints (default is 0).
Returns:
The `KinDynComputations` built from the `JaxSimModel` and `JaxSimModelData`.
Note:
Only `JaxSimModel` built from URDF files are supported.
"""
if (
isinstance(model.built_from, pathlib.Path)
and model.built_from.suffix != ".urdf"
) or (isinstance(model.built_from, str) and " KinDynComputations:
"""
Store the state of a `JaxSimModelData` in `KinDynComputations`.
Args:
data:
The `JaxSimModelData` providing the desired state to copy.
kin_dyn:
The `KinDynComputations` in which to store the state of `JaxSimModelData`.
Returns:
The updated `KinDynComputations` with the state of `JaxSimModelData`.
"""
if kin_dyn.dofs() != data.joint_positions.size:
raise ValueError(data)
with data.switch_velocity_representation(kin_dyn.vel_repr):
kin_dyn.set_robot_state(
joint_positions=np.array(data.joint_positions),
joint_velocities=np.array(data.joint_velocities),
base_transform=np.array(data._base_transform),
base_velocity=np.array(data.base_velocity),
)
return kin_dyn
@dataclasses.dataclass
class KinDynComputations:
"""High-level wrapper of the iDynTree KinDynComputations class."""
vel_repr: VelRepr
gravity: npt.NDArray
kin_dyn: idt.KinDynComputations
@staticmethod
def build(
    urdf: pathlib.Path | str,
    considered_joints: list[str] | None = None,
    vel_repr: VelRepr = VelRepr.Inertial,
    gravity: npt.NDArray = np.array([0, 0, -10.0]),
    removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None,
) -> KinDynComputations:
    """
    Build a `KinDynComputations` object from a URDF description.

    Args:
        urdf: The path to the URDF file, or the URDF description as a string.
        considered_joints:
            The joints to keep. If None, the full model is loaded, otherwise
            a reduced model is loaded.
        vel_repr: The velocity representation to configure.
        gravity: The gravity vector to store.
        removed_joint_positions:
            The positions of the removed joints, used only when loading a
            reduced model.

    Returns:
        The wrapper holding the configured iDynTree object.

    Raises:
        RuntimeError:
            If the description cannot be loaded, the model cannot be inserted,
            or the velocity representation cannot be set.
    """

    # Get the URDF description as a string.
    if isinstance(urdf, pathlib.Path):
        urdf_string = urdf.read_text()
    else:
        urdf_string = urdf

    # Create the model loader.
    loader = idt.ModelLoader()

    # Convert the positions of the removed joints to floats, falling back to
    # an empty dictionary if they are not provided.
    if removed_joint_positions is None:
        removed_joint_positions = {}
    else:
        removed_joint_positions = {
            name: float(pos) for name, pos in removed_joint_positions.items()
        }

    # Load either the full or the reduced model.
    if considered_joints is None:
        loaded = loader.loadModelFromString(urdf_string)
    else:
        loaded = loader.loadReducedModelFromString(
            urdf_string, considered_joints, removed_joint_positions
        )

    if not loaded:
        raise RuntimeError("Failed to load URDF description")

    # Create KinDynComputations and insert the model.
    idt_kin_dyn = idt.KinDynComputations()

    if not idt_kin_dyn.loadRobotModel(loader.model()):
        raise RuntimeError("Failed to load model")

    # Map the velocity representation to the corresponding iDynTree value.
    idt_representation = {
        VelRepr.Inertial: idt.INERTIAL_FIXED_REPRESENTATION,
        VelRepr.Body: idt.BODY_FIXED_REPRESENTATION,
        VelRepr.Mixed: idt.MIXED_REPRESENTATION,
    }[vel_repr]

    # Configure the frame representation.
    if not idt_kin_dyn.setFrameVelocityRepresentation(idt_representation):
        raise RuntimeError("Failed to set the frame representation")

    return KinDynComputations(
        kin_dyn=idt_kin_dyn,
        vel_repr=vel_repr,
        gravity=np.array(gravity).squeeze(),
    )
def set_robot_state(
    self,
    joint_positions: npt.NDArray | None = None,
    joint_velocities: npt.NDArray | None = None,
    base_transform: npt.NDArray = np.eye(4),
    base_velocity: npt.NDArray = np.zeros(6),
    world_gravity: npt.NDArray | None = None,
) -> None:
    """
    Set the robot state in the wrapped iDynTree object.

    Args:
        joint_positions: The joint positions (zeros if None).
        joint_velocities: The joint velocities (zeros if None).
        base_transform: The 4x4 transform of the base.
        base_velocity: The 6D velocity of the base.
        world_gravity: The gravity vector (the stored gravity if None).

    Raises:
        ValueError: If any input has an unexpected size or shape.
        RuntimeError: If iDynTree fails to set the robot state.
    """

    # Resolve the optional inputs.
    if joint_positions is None:
        joint_positions = np.zeros(self.dofs())

    if joint_velocities is None:
        joint_velocities = np.zeros(self.dofs())

    gravity = self.gravity if world_gravity is None else world_gravity

    # Validate the sizes and the shapes of the inputs.
    if joint_positions.size != self.dofs():
        raise ValueError(joint_positions.size, self.dofs())

    if joint_velocities.size != self.dofs():
        raise ValueError(joint_velocities.size, self.dofs())

    if gravity.size != 3:
        raise ValueError(gravity.size, 3)

    if base_transform.shape != (4, 4):
        raise ValueError(base_transform.shape, (4, 4))

    if base_velocity.size != 6:
        raise ValueError(base_velocity.size)

    # Convert the inputs to iDynTree objects.
    idt_gravity = idt.Vector3().FromPython(np.array(gravity))
    idt_positions = idt.VectorDynSize().FromPython(np.array(joint_positions))
    idt_velocities = idt.VectorDynSize().FromPython(np.array(joint_velocities))

    idt_base_position = idt.Position(
        *[float(i) for i in np.array(base_transform[0:3, 3])]
    )
    idt_base_rotation = idt.Rotation().FromPython(np.array(base_transform[0:3, 0:3]))

    world_H_base = idt.Transform()
    world_H_base.setPosition(idt_base_position)
    world_H_base.setRotation(idt_base_rotation)

    idt_base_velocity = idt.Twist().FromPython(base_velocity)

    state_was_set = self.kin_dyn.setRobotState(
        world_H_base, idt_positions, idt_base_velocity, idt_velocities, idt_gravity
    )

    if not state_was_set:
        raise RuntimeError("Failed to set the robot state")

    # Update stored gravity.
    self.gravity = gravity
def dofs(self) -> int:
    """Return the number of degrees of freedom reported by iDynTree."""

    number_of_dofs = self.kin_dyn.getNrOfDegreesOfFreedom()
    return number_of_dofs
def joint_names(self) -> list[str]:
    """Return the names of all the joints of the iDynTree model, by index."""

    idt_model: idt.Model = self.kin_dyn.model()
    joint_indices = range(idt_model.getNrOfJoints())

    return list(map(idt_model.getJointName, joint_indices))
def link_names(self) -> list[str]:
    """Return the names of the frames whose index is lower than the number of links."""

    names = []

    for link_index in range(self.kin_dyn.getNrOfLinks()):
        names.append(self.kin_dyn.getFrameName(link_index))

    return names
def frame_names(self) -> list[str]:
    """
    Return the names of the frames whose index starts from the number of links,
    i.e. the frames that are not returned by `link_names`.
    """

    first_index = self.kin_dyn.getNrOfLinks()
    end_index = self.kin_dyn.getNrOfFrames()

    names = []

    for frame_index in range(first_index, end_index):
        names.append(self.kin_dyn.getFrameName(frame_index))

    return names
def joint_positions(self) -> npt.NDArray:
    """
    Return the joint positions stored in the iDynTree object.

    Raises:
        RuntimeError: If iDynTree fails to provide the joint positions.
    """

    # Buffer filled by iDynTree.
    positions = idt.VectorDynSize()

    extracted = self.kin_dyn.getJointPos(positions)

    if not extracted:
        raise RuntimeError("Failed to extract joint positions")

    return positions.toNumPy()
def joint_velocities(self) -> npt.NDArray:
    """Return the joint velocities currently stored in ``self.kin_dyn``.

    Raises:
        RuntimeError: If ``getJointVel`` reports a failure.
    """
    # Output buffer filled in-place by the query.
    s_dot = idt.VectorDynSize()
    succeeded = self.kin_dyn.getJointVel(s_dot)
    if not succeeded:
        raise RuntimeError("Failed to extract joint velocities")
    return s_dot.toNumPy()
def jacobian_frame(self, frame_name: str) -> npt.NDArray:
    """Return the free-floating jacobian of the given frame.

    Args:
        frame_name: Name of the frame to query.

    Raises:
        ValueError: If the frame index lookup returns a negative value.
        RuntimeError: If the jacobian query reports a failure.
    """
    frame_index = self.kin_dyn.getFrameIndex(frame_name)
    if frame_index < 0:
        raise ValueError(f"Frame '{frame_name}' does not exist")
    # The output matrix is allocated as 6 x (6 + dofs) before the query.
    n_columns = 6 + self.dofs()
    jacobian = idt.MatrixDynSize(6, n_columns)
    succeeded = self.kin_dyn.getFrameFreeFloatingJacobian(frame_name, jacobian)
    if not succeeded:
        raise RuntimeError("Failed to get the frame free-floating jacobian")
    return jacobian.toNumPy()
def total_mass(self) -> float:
    """Return the total mass reported by the underlying model."""
    return self.kin_dyn.model().getTotalMass()
def link_spatial_inertia(self, link_name: str) -> npt.NDArray:
    """Return the matrix form of the inertia of the given link.

    Args:
        link_name: Name of the link; must be one of ``self.link_names()``.

    Raises:
        ValueError: If ``link_name`` is not among the known link names.
    """
    known_links = self.link_names()
    if link_name not in known_links:
        raise ValueError(link_name)
    model = self.kin_dyn.model()
    link_index = model.getLinkIndex(link_name)
    link = model.getLink(link_index)
    inertia = link.inertia()
    return inertia.asMatrix().toNumPy()
def link_mass(self, link_name: str) -> float:
    """Return the mass of the given link.

    The value is read as the first element of the vector form of the
    link inertia.

    Args:
        link_name: Name of the link; must be one of ``self.link_names()``.

    Raises:
        ValueError: If ``link_name`` is not among the known link names.
    """
    known_links = self.link_names()
    if link_name not in known_links:
        raise ValueError(link_name)
    model = self.kin_dyn.model()
    link_index = model.getLinkIndex(link_name)
    inertia_vector = model.getLink(link_index).getInertia().asVector().toNumPy()
    return inertia_vector[0]
def floating_base_frame(self) -> str:
    """Return the floating base reported by ``self.kin_dyn.getFloatingBase()``."""
    base_frame = self.kin_dyn.getFloatingBase()
    return base_frame
def frame_transform(self, frame_name: str) -> npt.NDArray:
    """Return the 4x4 homogeneous world transform of the given frame.

    The floating base frame is read through ``getWorldBaseTransform``;
    every other frame goes through ``getWorldTransform``.

    Args:
        frame_name: Name of the frame to query.

    Raises:
        ValueError: If the frame index lookup returns a negative value.
    """
    if self.kin_dyn.getFrameIndex(frame_name) < 0:
        raise ValueError(f"Frame '{frame_name}' does not exist")
    is_base = frame_name == self.floating_base_frame()
    world_H_frame = (
        self.kin_dyn.getWorldBaseTransform()
        if is_base
        else self.kin_dyn.getWorldTransform(frame_name)
    )
    # Start from the identity so the last row stays [0, 0, 0, 1].
    transform = np.eye(4)
    transform[0:3, 3] = world_H_frame.getPosition().toNumPy()
    transform[0:3, 0:3] = world_H_frame.getRotation().toNumPy()
    return transform
def frame_relative_transform(
    self, ref_frame_name: str, frame_name: str
) -> npt.NDArray:
    """Return the 4x4 homogeneous transform of ``frame_name`` relative to ``ref_frame_name``.

    Args:
        ref_frame_name: Name of the reference frame.
        frame_name: Name of the frame whose transform is requested.

    Raises:
        ValueError: If either frame index lookup returns a negative value.
            The reference frame is checked first.
    """
    for candidate in (ref_frame_name, frame_name):
        if self.kin_dyn.getFrameIndex(candidate) < 0:
            raise ValueError(f"Frame '{candidate}' does not exist")
    relative = self.kin_dyn.getRelativeTransform(ref_frame_name, frame_name)
    # Start from the identity so the last row stays [0, 0, 0, 1].
    transform = np.eye(4)
    transform[0:3, 3] = relative.getPosition().toNumPy()
    transform[0:3, 0:3] = relative.getRotation().toNumPy()
    return transform
def frame_parent_link_name(self, frame_name: str) -> str:
    """Return the name of the link that the given frame is attached to.

    Args:
        frame_name: Name of the frame to query.

    Returns:
        The name of the link returned by the model for this frame.

    Raises:
        ValueError: If the model does not know ``frame_name`` (the frame
            index lookup returns a negative value).
    """
    model = self.kin_dyn.model()
    frame_index = model.getFrameIndex(frame_name)
    # Reject unknown frames up front, as the other frame queries of this
    # class do, instead of forwarding an invalid index to the model.
    if frame_index < 0:
        raise ValueError(f"Frame '{frame_name}' does not exist")
    return model.getLinkName(model.getFrameLink(frame_index))
def base_velocity(self) -> npt.NDArray:
    """Return the first six entries of the model velocity vector.

    Raises:
        RuntimeError: If ``getModelVel`` reports a failure.
    """
    # Output buffer filled in-place with the full model velocity.
    model_velocity = idt.VectorDynSize()
    succeeded = self.kin_dyn.getModelVel(model_velocity)
    if not succeeded:
        raise RuntimeError("Failed to get the model velocity")
    full_velocity = model_velocity.toNumPy()
    return full_velocity[0:6]
def frame_velocity(self, frame_name: str) -> npt.NDArray:
    """Return the velocity of the given frame as provided by ``getFrameVel``.

    Args:
        frame_name: Name of the frame to query.

    Raises:
        ValueError: If the frame index lookup returns a negative value.
    """
    frame_index = self.kin_dyn.getFrameIndex(frame_name)
    if frame_index < 0:
        raise ValueError(f"Frame '{frame_name}' does not exist")
    return self.kin_dyn.getFrameVel(frame_name).toNumPy()
def frame_bias_acc(self, frame_name: str) -> npt.NDArray:
    """Return the bias acceleration of the given frame as provided by ``getFrameBiasAcc``.

    Args:
        frame_name: Name of the frame to query.

    Raises:
        ValueError: If the frame index lookup returns a negative value.
    """
    frame_index = self.kin_dyn.getFrameIndex(frame_name)
    if frame_index < 0:
        raise ValueError(f"Frame '{frame_name}' does not exist")
    bias_acceleration = self.kin_dyn.getFrameBiasAcc(frame_name)
    return bias_acceleration.toNumPy()
def com_position(self) -> npt.NDArray:
    """Return the center of mass position as provided by ``getCenterOfMassPosition``."""
    return self.kin_dyn.getCenterOfMassPosition().toNumPy()
def com_velocity(self) -> npt.NDArray:
    """Return the center of mass velocity as provided by ``getCenterOfMassVelocity``."""
    return self.kin_dyn.getCenterOfMassVelocity().toNumPy()
def com_bias_acceleration(self) -> npt.NDArray:
    """Return the center of mass bias acceleration as provided by ``getCenterOfMassBiasAcc``."""
    bias_acceleration = self.kin_dyn.getCenterOfMassBiasAcc()
    return bias_acceleration.toNumPy()
def mass_matrix(self) -> npt.NDArray:
    """Return the free-floating mass matrix.

    Raises:
        RuntimeError: If ``getFreeFloatingMassMatrix`` reports a failure.
    """
    # Output buffer filled in-place by the query.
    matrix = idt.MatrixDynSize()
    succeeded = self.kin_dyn.getFreeFloatingMassMatrix(matrix)
    if not succeeded:
        raise RuntimeError("Failed to get the free floating mass matrix")
    return matrix.toNumPy()
def bias_forces(self) -> npt.NDArray:
    """Return the generalized bias forces as a single flat array.

    The result stacks the flattened base wrench followed by the flattened
    joint torques.

    Raises:
        RuntimeError: If ``generalizedBiasForces`` reports a failure.
    """
    # Output container sized from the model and filled in-place.
    generalized = idt.FreeFloatingGeneralizedTorques(self.kin_dyn.model())
    succeeded = self.kin_dyn.generalizedBiasForces(generalized)
    if not succeeded:
        raise RuntimeError("Failed to get the generalized bias forces")
    base_part = generalized.baseWrench().toNumPy().flatten()
    joint_part = generalized.jointTorques().toNumPy().flatten()
    return np.hstack([base_part, joint_part])
def gravity_forces(self) -> npt.NDArray:
    """Return the generalized gravity forces as a single flat array.

    The result stacks the flattened base wrench followed by the flattened
    joint torques.

    Raises:
        RuntimeError: If ``generalizedGravityForces`` reports a failure.
    """
    # Output container sized from the model and filled in-place.
    generalized = idt.FreeFloatingGeneralizedTorques(self.kin_dyn.model())
    succeeded = self.kin_dyn.generalizedGravityForces(generalized)
    if not succeeded:
        raise RuntimeError("Failed to get the generalized gravity forces")
    base_part = generalized.baseWrench().toNumPy().flatten()
    joint_part = generalized.jointTorques().toNumPy().flatten()
    return np.hstack([base_part, joint_part])
def total_momentum(self) -> npt.NDArray:
    """Return the flattened result of ``getLinearAngularMomentum``."""
    momentum = self.kin_dyn.getLinearAngularMomentum()
    return momentum.toNumPy().flatten()
def centroidal_momentum(self) -> npt.NDArray:
    """Return the flattened result of ``getCentroidalTotalMomentum``."""
    momentum = self.kin_dyn.getCentroidalTotalMomentum()
    return momentum.toNumPy().flatten()
def total_momentum_jacobian(self) -> npt.NDArray:
    """Return the jacobian of the linear-angular momentum.

    Raises:
        RuntimeError: If ``getLinearAngularMomentumJacobian`` reports a failure.
    """
    # Output buffer filled in-place by the query.
    jacobian = idt.MatrixDynSize()
    succeeded = self.kin_dyn.getLinearAngularMomentumJacobian(jacobian)
    if not succeeded:
        raise RuntimeError("Failed to get the total momentum jacobian")
    return jacobian.toNumPy()
def centroidal_momentum_jacobian(self) -> npt.NDArray:
    """Return the jacobian of the centroidal total momentum.

    Raises:
        RuntimeError: If ``getCentroidalTotalMomentumJacobian`` reports a failure.
    """
    # Output buffer filled in-place by the query.
    jacobian = idt.MatrixDynSize()
    succeeded = self.kin_dyn.getCentroidalTotalMomentumJacobian(jacobian)
    if not succeeded:
        raise RuntimeError("Failed to get the centroidal momentum jacobian")
    return jacobian.toNumPy()
def locked_spatial_inertia(self) -> npt.NDArray:
    """Return the matrix form of the inertia given by ``getRobotLockedInertia``."""
    locked_inertia = self.kin_dyn.getRobotLockedInertia()
    inertia_matrix = locked_inertia.asMatrix()
    return inertia_matrix.toNumPy()
def locked_centroidal_spatial_inertia(self) -> npt.NDArray:
    """Return the matrix form of the inertia given by ``getCentroidalRobotLockedInertia``."""
    locked_inertia = self.kin_dyn.getCentroidalRobotLockedInertia()
    inertia_matrix = locked_inertia.asMatrix()
    return inertia_matrix.toNumPy()
def average_velocity(self) -> npt.NDArray:
    """Return the average velocity as provided by ``getAverageVelocity``."""
    velocity = self.kin_dyn.getAverageVelocity()
    return velocity.toNumPy()
def average_velocity_jacobian(self) -> npt.NDArray:
    """Return the jacobian of the average velocity.

    Raises:
        RuntimeError: If ``getAverageVelocityJacobian`` reports a failure.
    """
    # Output buffer filled in-place by the query.
    jacobian = idt.MatrixDynSize()
    succeeded = self.kin_dyn.getAverageVelocityJacobian(jacobian)
    if not succeeded:
        raise RuntimeError("Failed to get the average velocity jacobian")
    return jacobian.toNumPy()
def average_centroidal_velocity(self) -> npt.NDArray:
    """Return the centroidal average velocity as provided by ``getCentroidalAverageVelocity``."""
    velocity = self.kin_dyn.getCentroidalAverageVelocity()
    return velocity.toNumPy()
def average_centroidal_velocity_jacobian(self) -> npt.NDArray:
    """Return the jacobian of the centroidal average velocity.

    Raises:
        RuntimeError: If ``getCentroidalAverageVelocityJacobian`` reports a failure.
    """
    # Output buffer filled in-place by the query.
    jacobian = idt.MatrixDynSize()
    succeeded = self.kin_dyn.getCentroidalAverageVelocityJacobian(jacobian)
    if not succeeded:
        raise RuntimeError("Failed to get the average centroidal velocity jacobian")
    return jacobian.toNumPy()