Showing preview only (1,037K chars total). Download the full file or copy to clipboard to get everything.
Repository: ami-iit/jaxsim
Branch: main
Commit: 62ac1bc2707f
Files: 136
Total size: 988.2 KB
Directory structure:
gitextract_5j4outfv/
├── .devcontainer/
│ ├── Dockerfile
│ └── devcontainer.json
├── .gitattributes
├── .github/
│ ├── CODEOWNERS
│ ├── dependabot.yml
│ ├── release.yml
│ └── workflows/
│ ├── ci_cd.yml
│ ├── gpu_benchmark.yml
│ ├── pixi.yml
│ └── read_the_docs.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CITATION.cff
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs/
│ ├── Makefile
│ ├── conf.py
│ ├── examples.rst
│ ├── guide/
│ │ ├── configuration.rst
│ │ └── install.rst
│ ├── index.rst
│ ├── make.bat
│ └── modules/
│ ├── api.rst
│ ├── math.rst
│ ├── mujoco.rst
│ ├── parsers.rst
│ ├── rbda.rst
│ ├── typing.rst
│ └── utils.rst
├── environment.yml
├── examples/
│ ├── .gitattributes
│ ├── .gitignore
│ ├── README.md
│ ├── assets/
│ │ ├── build_cartpole_urdf.py
│ │ └── cartpole.urdf
│ ├── jaxsim_as_multibody_dynamics_library.ipynb
│ ├── jaxsim_as_physics_engine.ipynb
│ ├── jaxsim_as_physics_engine_advanced.ipynb
│ └── jaxsim_for_robot_controllers.ipynb
├── pyproject.toml
├── src/
│ └── jaxsim/
│ ├── __init__.py
│ ├── api/
│ │ ├── __init__.py
│ │ ├── actuation_model.py
│ │ ├── com.py
│ │ ├── common.py
│ │ ├── contact.py
│ │ ├── data.py
│ │ ├── frame.py
│ │ ├── integrators.py
│ │ ├── joint.py
│ │ ├── kin_dyn_parameters.py
│ │ ├── link.py
│ │ ├── model.py
│ │ ├── ode.py
│ │ └── references.py
│ ├── exceptions.py
│ ├── logging.py
│ ├── math/
│ │ ├── __init__.py
│ │ ├── adjoint.py
│ │ ├── cross.py
│ │ ├── inertia.py
│ │ ├── joint_model.py
│ │ ├── quaternion.py
│ │ ├── rotation.py
│ │ ├── skew.py
│ │ ├── transform.py
│ │ └── utils.py
│ ├── mujoco/
│ │ ├── __init__.py
│ │ ├── __main__.py
│ │ ├── loaders.py
│ │ ├── model.py
│ │ ├── utils.py
│ │ └── visualizer.py
│ ├── parsers/
│ │ ├── __init__.py
│ │ ├── descriptions/
│ │ │ ├── __init__.py
│ │ │ ├── collision.py
│ │ │ ├── joint.py
│ │ │ ├── link.py
│ │ │ └── model.py
│ │ ├── kinematic_graph.py
│ │ └── rod/
│ │ ├── __init__.py
│ │ ├── meshes.py
│ │ ├── parser.py
│ │ └── utils.py
│ ├── rbda/
│ │ ├── __init__.py
│ │ ├── aba.py
│ │ ├── aba_parallel.py
│ │ ├── actuation/
│ │ │ ├── __init__.py
│ │ │ └── common.py
│ │ ├── collidable_points.py
│ │ ├── contacts/
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── relaxed_rigid.py
│ │ │ ├── rigid.py
│ │ │ └── soft.py
│ │ ├── crba.py
│ │ ├── forward_kinematics.py
│ │ ├── forward_kinematics_parallel.py
│ │ ├── jacobian.py
│ │ ├── kinematic_constraints.py
│ │ ├── mass_inverse.py
│ │ ├── rnea.py
│ │ └── utils.py
│ ├── terrain/
│ │ ├── __init__.py
│ │ └── terrain.py
│ ├── typing.py
│ └── utils/
│ ├── __init__.py
│ ├── jaxsim_dataclass.py
│ ├── tracing.py
│ └── wrappers.py
└── tests/
├── __init__.py
├── assets/
│ ├── 4_bar_opened.urdf
│ ├── cube.stl
│ ├── double_pendulum.sdf
│ ├── mixed_shapes_robot.urdf
│ └── test_cube.urdf
├── conftest.py
├── test_actuation.py
├── test_api_com.py
├── test_api_contact.py
├── test_api_data.py
├── test_api_frame.py
├── test_api_joint.py
├── test_api_link.py
├── test_api_model.py
├── test_api_model_hw_parametrization.py
├── test_automatic_differentiation.py
├── test_benchmark.py
├── test_exceptions.py
├── test_meshes.py
├── test_pytree.py
├── test_simulations.py
├── test_visualizer.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .devcontainer/Dockerfile
================================================
# syntax=docker/dockerfile:1.4
FROM mcr.microsoft.com/devcontainers/base:jammy
ARG PROJECT_NAME=jaxsim
ARG PIXI_VERSION=v0.35.0
RUN curl -o /usr/local/bin/pixi -SL https://github.com/prefix-dev/pixi/releases/download/${PIXI_VERSION}/pixi-$(uname -m)-unknown-linux-musl \
&& chmod +x /usr/local/bin/pixi \
&& pixi info
# Add LFS repository and install.
RUN apt-get update && apt-get install -y curl \
&& curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash \
&& apt install -y git-lfs
USER vscode
WORKDIR /home/vscode
RUN echo 'eval "$(pixi completion -s bash)"' >> /home/vscode/.bashrc
================================================
FILE: .devcontainer/devcontainer.json
================================================
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/ubuntu
{
"name": "Ubuntu",
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
"build": {
"context": "..",
"dockerfile": "Dockerfile"
},
// Features to add to the dev container. More info: https://containers.dev/features.
"features": {
"ghcr.io/devcontainers/features/docker-in-docker:2": {}
},
// Put `.pixi` folder in a mounted volume of a case-insensitive filesystem.
"mounts": ["source=${localWorkspaceFolderBasename}-pixi,target=${containerWorkspaceFolder}/.pixi,type=volume"],
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand": "sudo chown vscode .pixi && git lfs pull --include='pixi.lock' && pixi install --environment=test-cpu",
// Configure tool-specific properties.
// "customizations": {},
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
// VSCode extensions
"customizations": {
"vscode": {
"settings": {
"python.pythonPath": "/workspaces/jaxsim/.pixi/envs/test-cpu/bin/python",
"python.defaultInterpreterPath": "/workspaces/jaxsim/.pixi/envs/test-cpu/bin/python",
"python.terminal.activateEnvironment": true,
"python.terminal.activateEnvInCurrentTerminal": true
},
"extensions": [
"ms-python.python",
"donjayamanne.python-extension-pack",
"ms-toolsai.jupyter",
"GitHub.codespaces",
"GitHub.copilot",
"ms-azuretools.vscode-docker",
"charliermarsh.ruff"
]
}
}
}
================================================
FILE: .gitattributes
================================================
# Track pixi.lock with Git LFS
pixi.lock filter=lfs diff=lfs merge=lfs -text
================================================
FILE: .github/CODEOWNERS
================================================
* @flferretti
================================================
FILE: .github/dependabot.yml
================================================
version: 2
updates:
# Check for updates to GitHub Actions every month.
- package-ecosystem: github-actions
directory: /
schedule:
interval: monthly
# Disable rebasing automatically existing pull requests.
rebase-strategy: "disabled"
# Group updates to a single PR.
groups:
dependencies:
patterns:
- '*'
================================================
FILE: .github/release.yml
================================================
changelog:
exclude:
authors:
- dependabot[bot]
- pre-commit-ci[bot]
- github-actions[bot]
================================================
FILE: .github/workflows/ci_cd.yml
================================================
name: Python CI/CD
on:
workflow_dispatch:
push:
pull_request:
release:
types:
- published
schedule:
# Execute a nightly build at 2am UTC.
- cron: '0 2 * * *'
jobs:
package:
name: Package the project
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.11"
- name: Install Python tools
run: pip install build twine
- name: Create distributions
run: python -m build -o dist/
- name: Inspect dist folder
run: ls -lah dist/
- name: Check wheel's abi and platform tags
run: test $(find dist/ -name *-none-any.whl | wc -l) -gt 0
- name: Run twine check
run: twine check dist/*
- name: Upload artifacts
uses: actions/upload-artifact@v7
with:
path: dist/*
name: dist
test:
name: 'Python${{ matrix.python }}@${{ matrix.os }}'
needs: package
runs-on: ${{ matrix.os }}
env:
PYTHONUTF8: "1"
strategy:
fail-fast: false
matrix:
os:
- ubuntu-latest
- macos-latest
- windows-latest
python:
- "3.10"
- "3.11"
- "3.12"
- "3.13"
steps:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python }}
- name: Download Python packages
uses: actions/download-artifact@v8
with:
path: dist
name: dist
- name: Install wheel (ubuntu)
if: contains(matrix.os, 'ubuntu')
shell: bash
run: pip install "$(find dist/ -type f -name '*.whl')"
- name: Install wheel (macos|windows)
if: contains(matrix.os, 'macos') || contains(matrix.os, 'windows')
shell: bash
run: pip install "$(find dist/ -type f -name '*.whl')"
- name: Document installed pip packages
shell: bash
run: pip list --verbose
- name: Import the package
run: python -c "import jaxsim"
- uses: actions/checkout@v6
with:
lfs: true
- uses: prefix-dev/setup-pixi@v0.9.5
if: contains(matrix.os, 'ubuntu')
with:
pixi-version: "latest"
frozen: true
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
- name: Run the Python tests
if: |
contains(matrix.os, 'ubuntu') &&
(github.event_name != 'pull_request')
run: pixi run --frozen test --numprocesses auto
env:
# https://github.com/pytest-dev/pytest/issues/7443#issuecomment-656642591
PY_COLORS: "1"
JAX_PLATFORM_NAME: cpu
publish:
name: Publish to PyPI
needs: test
runs-on: ubuntu-latest
permissions:
id-token: write
steps:
- name: Download Python packages
uses: actions/download-artifact@v8
with:
path: dist
name: dist
- name: Inspect dist folder
run: ls -lah dist/
- name: Publish to PyPI
if: |
github.repository == 'gbionics/jaxsim' &&
((github.event_name == 'push' && github.ref == 'refs/heads/main') ||
(github.event_name == 'release'))
uses: pypa/gh-action-pypi-publish@release/v1
with:
skip-existing: true
================================================
FILE: .github/workflows/gpu_benchmark.yml
================================================
name: GPU Benchmarks
on:
push:
branches:
- main
pull_request:
types: [opened, reopened, synchronize]
workflow_dispatch:
schedule:
- cron: "0 0 * * 1" # Run at 00:00 UTC every Monday
permissions:
pull-requests: write
deployments: write
contents: write
jobs:
benchmark:
runs-on: self-hosted
container:
image: ghcr.io/prefix-dev/pixi:0.46.0-noble@sha256:c12bcbe8ba5dfd71867495d3471b95a6993b79cc7de7eafec016f8f59e4e4961
options: --rm --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -e "TERM=xterm-256color"
steps:
- name: Install Git and Git-LFS
run: |
apt update && apt install -y git git-lfs
- name: Checkout repository
uses: actions/checkout@v6
with:
lfs: true
fetch-depth: 0
- name: Fetch pixi.lock from LFS
run: |
git config --global safe.directory /__w/jaxsim/jaxsim
git lfs checkout pixi.lock
- name: Get main branch SHA
id: get-main-branch-sha
run: |
SHA=$(git rev-parse origin/main)
echo "sha=$SHA" >> $GITHUB_OUTPUT
- name: Get benchmark results from main branch
id: cache
uses: actions/cache/restore@v5
with:
path: ./cache
key: ${{ runner.os }}-benchmark
- name: Run benchmark and store result
run: |
pixi run --frozen --environment gpu benchmark --batch-size 128 --benchmark-json output.json
env:
PY_COLORS: "1"
- name: Compare benchmark results with main branch
uses: benchmark-action/github-action-benchmark@v1.22.0
with:
tool: 'pytest'
output-file-path: output.json
external-data-json-path: ./cache/benchmark-data.json
save-data-file: false
fail-on-alert: true
summary-always: true
comment-always: true
alert-threshold: 150%
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Store benchmark result for main branch
uses: benchmark-action/github-action-benchmark@v1.22.0
if: ${{ github.ref_name == 'main' }}
with:
tool: 'pytest'
output-file-path: output.json
external-data-json-path: ./cache/benchmark-data.json
save-data-file: true
fail-on-alert: false
summary-always: true
comment-always: true
alert-threshold: 150%
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Publish Benchmark Results to GitHub Pages
uses: benchmark-action/github-action-benchmark@v1.22.0
if: ${{ github.ref_name == 'main' }}
with:
tool: 'pytest'
output-file-path: output.json
benchmark-data-dir-path: "benchmarks"
fail-on-alert: false
github-token: ${{ secrets.GITHUB_TOKEN }}
comment-on-alert: true
summary-always: true
save-data-file: true
alert-threshold: "150%"
auto-push: ${{ github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' }}
- name: Update Benchmark Results cache
uses: actions/cache/save@v5
if: ${{ github.ref_name == 'main' }}
with:
path: ./cache
key: ${{ runner.os }}-benchmark
================================================
FILE: .github/workflows/pixi.yml
================================================
name: Pixi
permissions:
contents: write
pull-requests: write
on:
workflow_dispatch:
schedule:
# Execute at 5am UTC on the first day of the month.
- cron: '0 5 1 * *'
jobs:
pixi-update:
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v6
with:
lfs: true
- name: Set up pixi
uses: prefix-dev/setup-pixi@v0.9.5
with:
run-install: false
- name: Install pixi-diff-to-markdown
run: pixi global install pixi-diff-to-markdown
- name: Update pixi lockfile and generate diff
run: |
set -o pipefail
pixi update --json | pixi exec pixi-diff-to-markdown --explicit-column > diff.md
- name: Test project against updated pixi
run: pixi run --environment default test
env:
PY_COLORS: "1"
JAX_PLATFORM_NAME: cpu
- name: Commit and push changes
run: echo "BRANCH_NAME=update-pixi-$(date +'%Y%m%d%H%M%S')" >> $GITHUB_ENV
- name: Create pull request
uses: peter-evans/create-pull-request@v8
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Update `pixi.lock`
title: Update `pixi` lockfile
body-path: diff.md
branch: ${{ env.BRANCH_NAME }}
base: main
labels: pixi
add-paths: pixi.lock
delete-branch: true
committer: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
================================================
FILE: .github/workflows/read_the_docs.yml
================================================
name: Read the Docs PR
on:
pull_request_target:
types:
- opened
permissions:
pull-requests: write
jobs:
documentation-links:
runs-on: ubuntu-latest
steps:
- uses: readthedocs/actions/preview@v1
with:
project-slug: "jaxsim"
project-language: ""
================================================
FILE: .gitignore
================================================
# IDEs
.idea*
.vscode/
# Matlab
*.m~
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
docs/_collections/
docs/modules/_autosummary/
docs/modules/generated
docs/sg_execution_times.rst
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# dynamic version
src/jaxsim/_version.py
# ruff
.ruff_cache/
# pixi environments
.pixi
# data
.mp4
.png
================================================
FILE: .pre-commit-config.yaml
================================================
ci:
autofix_prs: false
autoupdate_schedule: quarterly
submodules: false
default_language_version:
python: python3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-ast
- id: check-merge-conflict
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-toml
- id: check-added-large-files
args: ["--maxkb=2000"]
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 26.3.1
hooks:
- id: black
args: ["--check", "--diff"]
- repo: https://github.com/pycqa/isort
rev: 8.0.1
hooks:
- id: isort
args: ["--check", "--diff"]
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: rst-backticks
- id: rst-directive-colons
- id: rst-inline-touching-normal
- repo: https://github.com/codespell-project/codespell
rev: v2.4.2
hooks:
- id: codespell
args: ["-S", "*.lock"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.9
hooks:
- id: ruff
- repo: https://github.com/kynan/nbstripout
rev: 0.9.1
hooks:
- id: nbstripout
================================================
FILE: .readthedocs.yaml
================================================
version: "2"
build:
os: ubuntu-24.04
tools:
python: "mambaforge-23.11"
conda:
environment: environment.yml
python:
install:
- method: pip
path: .
sphinx:
configuration: docs/conf.py
formats: all
================================================
FILE: CITATION.cff
================================================
cff-version: 1.2.0
title: JaxSim
message: "If you use this software, please cite the paper."
type: software
authors:
- family-names: Ferretti
given-names: Filippo Luca
affiliation: "Generative Bionics"
- family-names: Ferigo
given-names: Diego
affiliation: "Robotics & AI Institute"
- family-names: Croci
given-names: Alessandro
affiliation: "NEURA Robotics"
- family-names: Sartore
given-names: Carlotta
affiliation: "Generative Bionics"
- family-names: Younis
given-names: Omar G.
affiliation: "Quebec AI Institute"
- family-names: Traversaro
given-names: Silvio
affiliation: "Generative Bionics"
- family-names: Pucci
given-names: Daniele
affiliation: "Generative Bionics"
repository-code: "https://github.com/gbionics/jaxsim"
license: BSD-3-Clause
preferred-citation:
type: article
title: "Contact-Aware Morphology Optimization via Physically Consistent Differentiable Simulation"
authors:
- family-names: Ferretti
given-names: Filippo Luca
- family-names: Ferigo
given-names: Diego
- family-names: Croci
given-names: Alessandro
- family-names: Sartore
given-names: Carlotta
- family-names: Younis
given-names: Omar G.
- family-names: Traversaro
given-names: Silvio
- family-names: Pucci
given-names: Daniele
journal: "IEEE Robotics and Automation Letters"
year: 2026
doi: "10.1109/LRA.2026.3678125"
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to JAXsim :rocket:
Hello Contributor,
We're thrilled that you're considering contributing to JAXsim!
Here's a brief guide to help you seamlessly become a part of our project.
## Development Environment :hammer_and_wrench:
Make sure your development environment is set up.
Follow the installation instructions in the [README](./README.md) to get JAXsim and its dependencies up and running.
To ensure consistency and maintain code quality, we recommend using the pre-commit hooks defined in [`.pre-commit-config.yaml`](./.pre-commit-config.yaml).
This will help catch issues before they become a part of the project.
### Setting Up Pre-commit Hook :fishing_pole_and_fish:
`pre-commit` is a tool that manages pre-commit hooks for your project.
It will run checks on your code before you commit it, ensuring that it meets the project's standards.
First, install `pre-commit` if you haven't already:
```bash
pip install pre-commit
```
Then, run the following command to install the hooks:
```bash
pre-commit install
```
### Using Pre-commit Hook :vertical_traffic_light:
Before making any commits, the pre-commit hook will automatically run.
If it finds any issues, it will prevent the commit and provide instructions on how to fix them.
To get your commit through without fixing the issues, use the `--no-verify` flag:
```bash
git commit -m "Your commit message" --no-verify
```
To manually run the pre-commit hook at any time, use:
```bash
pre-commit run --all-files
```
## Making Changes :construction:
If your pull request involves major changes, please open an issue to discuss them before submitting it.
This helps us understand your needs and provide feedback.
Clearly describe your pull request, referencing any related issues.
Follow the [PEP 8](https://peps.python.org/pep-0008/) style guide and include relevant tests.
## Testing :test_tube:
Your code will be tested with the CI/CD pipeline before merging.
Feel free to add new tests or update the existing ones in the [tests](./tests) folder to cover your changes; the CI/CD configuration lives in the [workflows](./.github/workflows) folder.
## Documentation :book:
Update the documentation in the [docs](./docs) folder and the [README](./README.md) to reflect your changes, if necessary.
There is no need to build the documentation locally; it will be automatically built and deployed with your pull request, where a preview link will be provided.
## Code Review :eyes:
Expect feedback during the code review process.
Address comments and make necessary changes.
This collaboration ensures quality.
Please keep the commit history clean, or squash commits if necessary.
## License :scroll:
JAXsim is under the [BSD 3-Clause License](./LICENSE).
By contributing, you agree to the same license.
Thank you for contributing to JAXsim! Your efforts are appreciated.
================================================
FILE: LICENSE
================================================
BSD 3-Clause License
Copyright (c) 2022, Artificial and Mechanical Intelligence
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: README.md
================================================
# JaxSim
**JaxSim** is a **differentiable physics engine** built with JAX, tailored for co-design and robotic learning applications.
<div align="center">
<br/>
<table>
<tr>
<th><img src="https://github.com/user-attachments/assets/89d0b4ca-7e0c-4f58-bf3e-9540e35b9a01" style="height:300px; width:400px; object-fit:cover;"></th>
<th><img src="https://github.com/user-attachments/assets/a909e388-d7b4-4b58-89f1-035da8636d94" style="height:300px; width:400px; object-fit:cover;"></th>
</tr>
<tr>
<th><img src="https://github.com/user-attachments/assets/3692bc06-18ed-406d-80bd-480780346224" style="height:300px; width:400px; object-fit:cover;"></th>
<th><img src="https://github.com/user-attachments/assets/3356f332-4710-4946-9a82-a8c2305dab88" style="height:300px; width:400px; object-fit:cover;"></th>
</tr>
</table>
<br/>
</div>
## Features
- Physically consistent differentiability w.r.t. hardware parameters.
- Closed chain dynamics support.
- Reduced-coordinate physics engine for **fixed-base** and **floating-base** robots.
- Fully Python-based, leveraging [JAX][jax] following a functional programming paradigm.
- Seamless execution on CPUs, GPUs, and TPUs.
- Supports JIT compilation and automatic vectorization for high performance.
- Compatible with SDF models and URDF (via [sdformat][sdformat] conversion).
> [!WARNING]
> This project is still experimental. APIs may change between releases without notice.
> [!NOTE]
> JaxSim currently focuses on locomotion applications.
> Only contacts between bodies and smooth ground surfaces are supported.
## How to use it
```python
import pathlib
import icub_models
import jax.numpy as jnp
import jaxsim.api as js
# Load the iCub model
model_path = icub_models.get_model_file("iCubGazeboV2_5")
joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',
'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',
'r_ankle_roll')
# Build and reduce the model
model_description = pathlib.Path(model_path)
full_model = js.model.JaxSimModel.build_from_model_description(
model_description=model_description,
)
model = js.model.reduce(model=full_model, considered_joints=joints)
# Get the number of degrees of freedom
ndof = model.dofs()
# Initialize data and simulation
# Note that the default data representation is mixed velocity representation
data = js.data.JaxSimModelData.build(
model=model, base_position=jnp.array([0.0, 0.0, 1.0])
)
T = jnp.arange(start=0, stop=1.0, step=model.time_step)
tau = jnp.zeros(ndof)
# Simulate
for _ in T:
data = js.model.step(
model=model, data=data, link_forces=None, joint_force_references=tau
)
```
Check the [examples](./examples) folder for additional use cases!
[jax]: https://github.com/google/jax/
[sdformat]: https://github.com/gazebosim/sdformat
[notation]: https://research.tue.nl/en/publications/multibody-dynamics-notation-version-2
[passive_viewer_mujoco]: https://mujoco.readthedocs.io/en/stable/python.html#passive-viewer
## Installation
<details>
<summary>With <code>conda</code></summary>
You can install the project using [`conda`][conda] as follows:
```bash
conda install jaxsim -c conda-forge
```
GPU support for JAX will be automatically installed if a compatible GPU is detected.
</details>
<details>
<summary>With <code>pixi</code></summary>
> ### Note
> The minimum version of `pixi` required is `0.39.0`.
Since the `pixi.lock` file is stored using Git LFS, make sure you have [Git LFS](https://github.com/git-lfs/git-lfs/blob/main/INSTALLING.md) installed and properly configured on your system before installation. After cloning the repository, run:
```bash
git lfs install && git lfs pull
```
This ensures all LFS-tracked files are properly downloaded before you proceed with the installation.
You can add the `jaxsim` dependency in your [`pixi`][pixi] project as follows:
```bash
pixi add jaxsim
```
If you are on Linux and you want to use a `cuda`-powered version of `jax`, remember to add the appropriate line in the [`system-requirements`](https://pixi.sh/latest/reference/pixi_manifest/#the-system-requirements-table) table, i.e. adding
~~~toml
[system-requirements]
cuda = "13"
~~~
if you are using a `pixi.toml` file or
~~~toml
[tool.pixi.system-requirements]
cuda = "13"
~~~
if you are using a `pyproject.toml` file.
</details>
<details>
<summary>With <code>pip</code></summary>
You can install the project using [`pypa/pip`][pip], preferably in a [virtual environment][venv], as follows:
```bash
pip install jaxsim
```
Check [`pyproject.toml`](pyproject.toml) for the complete list of optional dependencies.
You can obtain a full installation using `jaxsim[all]`.
If you need URDF support, follow the [official instructions](https://gazebosim.org/docs) to install Gazebo Sim on your operating system,
making sure to obtain `sdformat ≥ 13.0` and `gz-tools ≥ 2.0`.
You don't need to install the entire Gazebo Sim suite.
For example, on Ubuntu, it is sufficient to install the `libsdformat*` and `gz-tools2` packages.
If you need GPU support, follow the official [installation instructions][jax_gpu] of JAX.
</details>
<details>
<summary>Contributors installation (with <code>conda</code>)</summary>
If you want to contribute to the project, we recommend creating the following `jaxsim` conda environment first:
```bash
conda env create -f environment.yml
```
Then, activate the environment and install the project in editable mode:
```bash
conda activate jaxsim
pip install --no-deps -e .
```
</details>
<details>
<summary>Contributors installation (with <code>pixi</code>)</summary>
> ### Note
> The minimum version of `pixi` required is `0.39.0`.
Since the `pixi.lock` file is stored using Git LFS, make sure you have [Git LFS](https://github.com/git-lfs/git-lfs/blob/main/INSTALLING.md) installed and properly configured on your system before installation. After cloning the repository, run:
```bash
git lfs install && git lfs pull
```
This ensures all LFS-tracked files are properly downloaded before you proceed with the installation.
You can install the default dependencies of the project using [`pixi`][pixi] as follows:
```bash
pixi install
```
See `pixi task list` for a list of available tasks.
</details>
[conda]: https://anaconda.org/
[pip]: https://github.com/pypa/pip/
[pixi]: https://pixi.sh/
[venv]: https://docs.python.org/3/tutorial/venv.html
[jax_gpu]: https://github.com/google/jax/#installation
## Documentation
The JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs].
[readthedocs]: https://jaxsim.readthedocs.io/
## Additional features
JaxSim can also be used as a multibody dynamics library, with full support for automatic differentiation of RBDAs (forward and reverse mode) with respect to both kinematic and dynamic parameters.
### Using JaxSim as a multibody dynamics library
```python
import pathlib
import icub_models
import jax.numpy as jnp
import jaxsim.api as js
# Load the iCub model
model_path = icub_models.get_model_file("iCubGazeboV2_5")
joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',
'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',
'r_ankle_roll')
# Build and reduce the model
model_description = pathlib.Path(model_path)
full_model = js.model.JaxSimModel.build_from_model_description(
model_description=model_description,
)
model = js.model.reduce(model=full_model, considered_joints=joints)
# Initialize model data
data = js.data.JaxSimModelData.build(
model=model,
base_position=jnp.array([0.0, 0.0, 1.0]),
)
# Frame and dynamics computations
frame_index = js.frame.name_to_idx(model=model, frame_name="l_foot")
# Frame transformation
W_H_F = js.frame.transform(
model=model, data=data, frame_index=frame_index
)
# Frame Jacobian
W_J_F = js.frame.jacobian(
model=model, data=data, frame_index=frame_index
)
# Dynamics properties
M = js.model.free_floating_mass_matrix(model=model, data=data) # Mass matrix
h = js.model.free_floating_bias_forces(model=model, data=data) # Bias forces
g = js.model.free_floating_gravity_forces(model=model, data=data) # Gravity forces
C = js.model.free_floating_coriolis_matrix(model=model, data=data) # Coriolis matrix
# Print dynamics results
print(f"{M.shape=} \n{h.shape=} \n{g.shape=} \n{C.shape=}")
```
## Credits
The RBDAs are based on the theory of the [Rigid Body Dynamics Algorithms][RBDA]
book by Roy Featherstone.
The algorithms and some simulation features were inspired by its accompanying [code][spatial_v2].
[RBDA]: https://link.springer.com/book/10.1007/978-1-4899-7560-7
[spatial_v2]: http://royfeatherstone.org/spatial/index.html#spatial-software
The development of JaxSim started in late 2021, inspired by early versions of [`google/brax`][brax].
At that time, Brax was implemented in maximal coordinates, and we wanted a physics engine in reduced coordinates.
We are grateful to the Brax team for their work and for showing the potential of [JAX][jax] in this field.
Brax v2 was later implemented with reduced coordinates, following an approach comparable to JaxSim.
The development then shifted to [MJX][mjx], which provides a JAX-based implementation of the MuJoCo APIs.
The main differences between MJX/Brax and JaxSim are as follows:
- JaxSim supports all SDF models with [Pose Frame Semantics][PFS] out of the box.
- JaxSim only supports collisions between points rigidly attached to bodies and a compliant ground surface.
[brax]: https://github.com/google/brax
[mjx]: https://mujoco.readthedocs.io/en/3.0.0/mjx.html
[PFS]: http://sdformat.org/tutorials?tut=pose_frame_semantics
## Contributing
We welcome contributions from the community.
Please read the [contributing guide](./CONTRIBUTING.md) to get started.
## Citing
If you use JaxSim in your work, please cite the following paper:
```bibtex
@article{ferretti_contact_aware_2026,
author = {Filippo Luca Ferretti and Diego Ferigo and Alessandro Croci and Carlotta Sartore and Omar G. Younis and Silvio Traversaro and Daniele Pucci},
title = {Contact-Aware Morphology Optimization via Physically Consistent Differentiable Simulation},
journal = {IEEE Robotics and Automation Letters},
year = {2026},
doi = {10.1109/LRA.2026.3678125}
}
```
## People
| Authors | Maintainer |
|:------:|:-----------:|
| [<img src="https://avatars.githubusercontent.com/u/469199?v=4" width="40">][df] [<img src="https://avatars.githubusercontent.com/u/102977828?v=4" width="40">][ff] | [<img src="https://avatars.githubusercontent.com/u/102977828?v=4" width="40">][ff] |
[df]: https://github.com/diegoferigo
[ff]: https://github.com/flferretti
## License
[BSD3](https://choosealicense.com/licenses/bsd-3-clause/)
================================================
FILE: docs/Makefile
================================================
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
SPHINXPROJ = JAXsim
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: docs/conf.py
================================================
# Configuration file for the Sphinx documentation builder.
import os
import sys

# -- Read the Docs environment
# When building on Read the Docs, export CONDA_PREFIX pointing to a
# "conda/<name>" directory resolved relative to the current working directory,
# where <name> is the name of the directory containing this file.
# NOTE(review): confirm this path matches the conda layout used by the RTD build.
if os.environ.get("READTHEDOCS"):
    checkout_name = os.path.basename(os.path.dirname(os.path.realpath(__file__)))
    os.environ["CONDA_PREFIX"] = os.path.realpath(
        os.path.join("..", "..", "conda", checkout_name)
    )

# -- Path setup
# The search path must be extended *before* importing jaxsim. Inserting these
# entries after the import (as done previously) has no effect on which jaxsim
# package is picked up. The "../src" entry makes the in-tree sources importable.
sys.path.insert(0, os.path.abspath("."))
sys.path.insert(0, os.path.abspath("../"))
sys.path.insert(0, os.path.abspath("../../"))
module_path = os.path.abspath("../src/")
sys.path.insert(0, module_path)

import jaxsim  # noqa: E402 -- must follow the sys.path setup above.

# -- Version information
__version__ = jaxsim._version.__version__

# -- Project information
project = "JAXsim"
copyright = "2022, Artificial and Mechanical Intelligence"
author = "Artificial and Mechanical Intelligence"
release = version = __version__

# -- General configuration
extensions = [
    "sphinx.ext.duration",
    "sphinx.ext.doctest",
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.intersphinx",
    "sphinx.ext.mathjax",
    "sphinx.ext.ifconfig",
    "sphinx.ext.viewcode",
    "sphinx_rtd_theme",
    "sphinx.ext.napoleon",
    "sphinx_autodoc_typehints",
    "sphinx_multiversion",
    "myst_nb",
    "sphinx_gallery.gen_gallery",
    "sphinxcontrib.collections",
    "sphinx_design",
]

language = "en"

# -- Options for HTML output
html_theme = "sphinx_book_theme"
templates_path = ["_templates"]
html_title = f"JAXsim {version}"
master_doc = "index"

# -- Options for autodoc / autosummary
autodoc_typehints_format = "short"
autodoc_typehints = "description"
autosummary_generate = True

# -- Options for EPUB output
epub_show_urls = "footnote"

# Show the jaxsim.typing aliases by name rather than expanding them.
# Sphinx applies autodoc_type_aliases only to modules that use postponed
# evaluation of annotations (PEP 563, ``from __future__ import annotations``).
autodoc_type_aliases = {
    "jaxsim.typing.PyTree": "jaxsim.typing.PyTree",
    "jaxsim.typing.Vector": "jaxsim.typing.Vector",
    "jaxsim.typing.Matrix": "jaxsim.typing.Matrix",
    "jaxsim.typing.Array": "jaxsim.typing.Array",
    "jaxsim.typing.Int": "jaxsim.typing.Int",
    "jaxsim.typing.Bool": "jaxsim.typing.Bool",
    "jaxsim.typing.Float": "jaxsim.typing.Float",
    "jaxsim.typing.ScalarLike": "jaxsim.typing.ScalarLike",
    "jaxsim.typing.ArrayLike": "jaxsim.typing.ArrayLike",
    "jaxsim.typing.VectorLike": "jaxsim.typing.VectorLike",
    "jaxsim.typing.MatrixLike": "jaxsim.typing.MatrixLike",
    "jaxsim.typing.IntLike": "jaxsim.typing.IntLike",
    "jaxsim.typing.BoolLike": "jaxsim.typing.BoolLike",
    "jaxsim.typing.FloatLike": "jaxsim.typing.FloatLike",
}

# -- Options for sphinx-collections
# Copy the example notebooks (without the assets folder) into the docs tree.
collections = {
    "examples": {"driver": "copy_folder", "source": "../examples/", "ignore": "assets"}
}

# -- Options for sphinx-gallery ----------------------------------------------
sphinx_gallery_conf = {
    "examples_dirs": "../examples",
    "gallery_dirs": "../generated_examples/",
    "doc_module": "jaxsim",
}

# -- Options for myst -------------------------------------------------------
myst_enable_extensions = [
    "amsmath",
    "dollarmath",
]
nb_execution_mode = "auto"
nb_execution_raise_on_error = True
nb_render_image_options = {
    "scale": "60",
}
nb_execution_timeout = 180
source_suffix = [".rst", ".md", ".ipynb"]

# Ignore header warnings
suppress_warnings = ["myst.header"]
================================================
FILE: docs/examples.rst
================================================
.. _collections:
Example Notebooks
=================
.. toctree::
:glob:
:hidden:
:maxdepth: 1
_collections/examples/README.md
.. raw:: html
<div class="sphx-glr-thumbnails">
<div class="sphx-glr-thumbcontainer" tooltip="JaxSim as a hardware-accelerated parallel physics engine">
.. only:: html
:doc:`_collections/examples/jaxsim_as_physics_engine`
.. raw:: html
<div class="sphx-glr-thumbnail-title">JaxSim as a hardware-accelerated parallel physics engine</div>
</div>
<div class="sphx-glr-thumbcontainer" tooltip="JaxSim as a hardware-accelerated parallel physics engine [Advanced]">
.. only:: html
:doc:`_collections/examples/jaxsim_as_physics_engine_advanced`
.. raw:: html
<div class="sphx-glr-thumbnail-title">JaxSim as a hardware-accelerated parallel physics engine [Advanced]</div>
</div>
<div class="sphx-glr-thumbcontainer" tooltip="JaxSim as a multibody dynamics library">
.. only:: html
:doc:`_collections/examples/jaxsim_as_multibody_dynamics_library`
.. raw:: html
<div class="sphx-glr-thumbnail-title">JaxSim as a multibody dynamics library</div>
</div>
<div class="sphx-glr-thumbcontainer" tooltip="JaxSim for developing closed-loop robot controllers">
.. only:: html
:doc:`_collections/examples/jaxsim_for_robot_controllers`
.. raw:: html
<div class="sphx-glr-thumbnail-title">JaxSim for developing closed-loop robot controllers</div>
</div>
</div>
================================================
FILE: docs/guide/configuration.rst
================================================
Configuration
=============
JaxSim utilizes environment variables for application configuration. Below is a detailed overview of the various configuration categories and their respective variables.
Collision Dynamics
~~~~~~~~~~~~~~~~~~
Environment variables starting with ``JAXSIM_COLLISION_`` are used to configure collision dynamics. The available variables are:
- ``JAXSIM_COLLISION_SPHERE_POINTS``: Specifies the number of collision points to approximate the sphere.
*Default:* ``50``.
- ``JAXSIM_COLLISION_MESH_ENABLED``: Enables or disables mesh-based collision detection.
*Default:* ``False``.
- ``JAXSIM_COLLISION_USE_BOTTOM_ONLY``: Limits collision detection to only the bottom half of the box or sphere.
*Default:* ``False``.
.. note::
The bottom half is defined as the half of the box or sphere with the lowest z-coordinate in the collision link frame.
Testing
~~~~~~~
For testing configurations, environment variables beginning with ``JAXSIM_TEST_`` are used. The following variables are available:
- ``JAXSIM_TEST_SEED``: Defines the seed for the random number generator.
*Default:* ``0``.
- ``JAXSIM_TEST_AD_ORDER``: Specifies the gradient order for automatic differentiation tests.
*Default:* ``1``.
- ``JAXSIM_TEST_FD_STEP_SIZE``: Sets the step size for finite difference tests.
*Default:* the cube root of the machine epsilon.
Joint Dynamics
~~~~~~~~~~~~~~
Joint dynamics are configured using environment variables starting with ``JAXSIM_JOINT_``. Available variables include:
- ``JAXSIM_JOINT_POSITION_LIMIT_DAMPER``: Overrides the damper value for joint position limits of the SDF model.
- ``JAXSIM_JOINT_POSITION_LIMIT_SPRING``: Overrides the spring value for joint position limits of the SDF model.
Logging and Exceptions
~~~~~~~~~~~~~~~~~~~~~~
The logging and exception configuration is controlled by the following environment variables:
- ``JAXSIM_LOGGING_LEVEL``: Determines the logging level.
*Default:* ``DEBUG`` for development, ``WARNING`` for production.
- ``JAXSIM_ENABLE_EXCEPTIONS``: Enables the runtime checks and exceptions. Note that enabling exceptions might lead to device-to-host transfer of data, increasing the computational time required.
*Default:* ``False``.
.. note::
Runtime exceptions are disabled by default on TPU.
================================================
FILE: docs/guide/install.rst
================================================
Installation
============
.. _installation:
Prerequisites
-------------
JAXsim requires Python 3.11 or later.
Basic Installation
------------------
You can install the project using `conda`_:
.. code-block:: bash
conda install jaxsim -c conda-forge
Alternatively, you can use `pypa/pip`_, preferably in a `virtual environment`_:
.. code-block:: bash
pip install jaxsim
Have a look at `pyproject.toml`_ for a complete list of optional dependencies.
You can install all of them with ``pip install "jaxsim[all]"``.
.. note::
If you need GPU support, please follow the official `installation instruction`_ of JAX.
.. _conda: https://anaconda.org/
.. _pyproject.toml: https://github.com/gbionics/jaxsim/blob/main/pyproject.toml
.. _pypa/pip: https://github.com/pypa/pip/
.. _virtual environment: https://docs.python.org/3/tutorial/venv.html
.. _installation instruction: https://github.com/google/jax/#installation
================================================
FILE: docs/index.rst
================================================
JAXsim
#######
A scalable physics engine and multibody dynamics library implemented with JAX. With JIT batteries 🔋
.. note::
This simulator currently focuses on locomotion applications. Only contacts with ground are supported.
Features
--------
.. grid::
.. grid-item::
:columns: 12 12 12 6
.. card:: Performance
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5
.. div:: sd-font-normal
Physics engine in reduced coordinates implemented with JAX_.
Compatibility with JIT compilation for increased performance and transparent support to execute logic on CPUs, GPUs, and TPUs.
Parallel multi-body simulations on hardware accelerators for significantly increased throughput
.. grid-item::
:columns: 12 12 12 6
.. card:: Model Parsing
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5
.. div:: sd-font-normal
Support for SDF models (and, upon conversion, URDF models). Revolute, prismatic, and fixed joints supported.
.. grid-item::
:columns: 12 12 12 6
.. card:: Automatic Differentiation
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5
.. div:: sd-font-normal
Support for automatic differentiation of rigid body dynamics algorithms (RBDAs) for model-based robotics research.
Soft contacts model supporting full friction cone and sticking / slipping transition.
.. grid-item::
:columns: 12 12 12 6
.. card:: Complex Dynamics
:class-card: sd-border-0
:shadow: none
:class-title: sd-fs-5
.. div:: sd-font-normal
JAXsim provides a variety of integrators for the simulation of multibody dynamics, including RK4, Heun, Euler, and more.
Support of `multiple velocities representations <https://research.tue.nl/en/publications/multibody-dynamics-notation-version-2>`_.
----
.. toctree::
:hidden:
guide/install
guide/configuration
examples
.. toctree::
:hidden:
:maxdepth: 2
:caption: JAXsim API
modules/api
modules/math
modules/mujoco
modules/parsers
modules/rbda
modules/typing
modules/utils
Examples
--------
Explore and learn how to use the library through practical demonstrations available in the `examples <https://github.com/gbionics/jaxsim/tree/main/examples>`__ folder.
Credits
-------
The physics module of JAXsim is based on the theory of the `Rigid Body Dynamics Algorithms <https://link.springer.com/book/10.1007/978-1-4899-7560-7>`_ book by Roy Featherstone.
We structured part of our logic following its accompanying `code <http://royfeatherstone.org/spatial/index.html#spatial-software>`_.
The physics engine is developed entirely in Python using JAX_.
The inspiration for developing JAXsim originally stemmed from early versions of Brax_.
Here below we summarize the differences between the projects:
- JAXsim simulates multibody dynamics in reduced coordinates, while :code:`brax v1` uses maximal coordinates.
- The new v2 APIs of brax (and the new MJX_) were then implemented in reduced coordinates, following an approach comparable to JAXsim, with major differences in contact handling.
- The rigid-body algorithms used in JAXsim allow the efficient computation of quantities based on the Euler-Poincaré
formulation of the equations of motion, necessary for model-based robotics research.
- JAXsim supports SDF (and, indirectly, URDF) models, assuming the model is described with the
recent `Pose Frame Semantics <http://sdformat.org/tutorials?tut=pose_frame_semantics>`_.
- Unlike brax, JAXsim only supports collision detection between bodies and a compliant ground surface.
- The RBDAs of JAXsim support automatic differentiation, but this functionality has not been thoroughly tested.
People
------
Authors
'''''''
`Diego Ferigo <https://github.com/diegoferigo>`_
`Filippo Luca Ferretti <https://github.com/flferretti>`_
Maintainers
'''''''''''
`Filippo Luca Ferretti <https://github.com/flferretti>`_
`Alessandro Croci <https://github.com/xela-95>`_
License
-------
`BSD3 <https://choosealicense.com/licenses/bsd-3-clause/>`_
.. _Brax: https://github.com/google/brax
.. _MJX: https://mujoco.readthedocs.io/en/3.0.0/mjx.html
.. _JAX: https://github.com/google/jax
================================================
FILE: docs/make.bat
================================================
@ECHO OFF
pushd %~dp0
REM
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
set SPHINXPROJ=JaxSim
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
================================================
FILE: docs/modules/api.rst
================================================
Functional API
==============
.. currentmodule:: jaxsim.api
.. autosummary::
:toctree: _autosummary
model
data
contact
kin_dyn_parameters
integrators
joint
link
frame
com
ode
references
actuation_model
common
Model
~~~~~
.. automodule:: jaxsim.api.model
:members:
:no-index:
.. automodule:: jaxsim.api.actuation_model
:members:
:no-index:
Data
~~~~
.. automodule:: jaxsim.api.data
:members:
:no-index:
Contact
~~~~~~~
.. automodule:: jaxsim.api.contact
:members:
:no-index:
KinDynParameters
~~~~~~~~~~~~~~~~
.. automodule:: jaxsim.api.kin_dyn_parameters
:members:
:no-index:
Joint
~~~~~
.. automodule:: jaxsim.api.joint
:members:
:no-index:
Link
~~~~~
.. automodule:: jaxsim.api.link
:members:
:no-index:
Frame
~~~~~
.. automodule:: jaxsim.api.frame
:members:
:no-index:
CoM
~~~
.. automodule:: jaxsim.api.com
:members:
:no-index:
Integration
~~~~~~~~~~~
.. automodule:: jaxsim.api.integrators
:members:
:no-index:
.. automodule:: jaxsim.api.ode
:members:
:no-index:
References
~~~~~~~~~~
.. automodule:: jaxsim.api.references
:members:
:no-index:
Common
~~~~~~
.. autoclass:: jaxsim.api.common.VelRepr
:members:
.. autoclass:: jaxsim.api.common.ModelDataWithVelocityRepresentation
:members:
================================================
FILE: docs/modules/math.rst
================================================
Math
====
.. currentmodule:: jaxsim.math
.. automodule:: jaxsim.math.adjoint
:members:
:undoc-members:
.. automodule:: jaxsim.math.cross
:members:
:undoc-members:
.. automodule:: jaxsim.math.inertia
:members:
:undoc-members:
.. automodule:: jaxsim.math.quaternion
:members:
:undoc-members:
.. automodule:: jaxsim.math.rotation
:members:
:undoc-members:
.. automodule:: jaxsim.math.skew
:members:
:undoc-members:
================================================
FILE: docs/modules/mujoco.rst
================================================
MuJoCo Visualizer
==================
JAXsim provides a simple interface with MuJoCo's visualizer. The visualizer is
a separate process that communicates with the main simulation process. This
allows for the simulation to run at full speed while the visualizer can run at
a different frame rate.
.. currentmodule:: jaxsim.mujoco
Loaders
~~~~~~~
.. automodule:: jaxsim.mujoco.loaders
:members:
Model
~~~~~
.. automodule:: jaxsim.mujoco.model
:members:
Visualizer
~~~~~~~~~~
.. automodule:: jaxsim.mujoco.visualizer
:members:
================================================
FILE: docs/modules/parsers.rst
================================================
Parsers
=======
.. automodule:: jaxsim.parsers.descriptions.collision
:members:
.. automodule:: jaxsim.parsers.descriptions.joint
:members:
.. automodule:: jaxsim.parsers.descriptions.link
:members:
.. automodule:: jaxsim.parsers.descriptions.model
:members:
================================================
FILE: docs/modules/rbda.rst
================================================
Rigid Body Dynamics Algorithms
==============================
This module provides a set of algorithms for rigid body dynamics.
.. currentmodule:: jaxsim.rbda
.. autosummary::
:toctree: _autosummary
aba
collidable_points
contacts.soft
contacts.rigid
contacts.relaxed_rigid
crba
forward_kinematics
jacobian
utils
Collision Detection
~~~~~~~~~~~~~~~~~~~
.. automodule:: jaxsim.rbda.collidable_points
:members:
:no-index:
Contact Models
~~~~~~~~~~~~~~
.. automodule:: jaxsim.rbda.contacts.soft
:members:
:no-index:
.. automodule:: jaxsim.rbda.contacts.rigid
:members:
:no-index:
.. automodule:: jaxsim.rbda.contacts.relaxed_rigid
:members:
:no-index:
Utilities
~~~~~~~~~
.. automodule:: jaxsim.rbda.utils
:members:
:no-index:
================================================
FILE: docs/modules/typing.rst
================================================
Typing
======
.. currentmodule:: jaxsim.typing
.. autosummary::
PyTree
Matrix
Bool
Int
Float
Vector
BoolLike
FloatLike
IntLike
ArrayLike
VectorLike
MatrixLike
================================================
FILE: docs/modules/utils.rst
================================================
Utils
=====
.. automodule:: jaxsim.utils
:members:
:inherited-members:
.. autoclass:: jaxsim.utils.JaxsimDataclass
:members:
:inherited-members:
================================================
FILE: environment.yml
================================================
name: jaxsim
channels:
- conda-forge
dependencies:
# ================================
# Dependencies from pyproject.toml
# ================================
- python >= 3.12.0
- coloredlogs
- jax >= 0.4.34
- jaxlib >= 0.4.34
- jaxlie >= 1.3.0
- jax-dataclasses >= 1.4.0
- optax >= 0.2.3
- pptree
- qpax
- rod >= 0.3.3
- trimesh
- typing_extensions # python<3.12
# =========================================
# Optional dependencies from pyproject.toml
# =========================================
# [testing]
- chex
- idyntree >= 12.2.1
- pytest
- pytest-benchmark
- pytest-icdiff
- robot_descriptions >= 1.16.0
- icub-models
# [viz]
- lxml
- mediapy
- mujoco >= 3.0.0
- scipy >= 1.14.0
# ==========================
# Documentation dependencies
# ==========================
- cachecontrol
- filecache
- jinja2
- myst-nb
- pip
- sphinx
- sphinx-autodoc-typehints
- sphinx-book-theme
- sphinx-copybutton
- sphinx-design
- sphinx_fontawesome
- sphinx-gallery
- sphinx-jinja2-compat
- sphinx-multiversion
- sphinx_rtd_theme
- sphinx-toolbox
- icub-models
# ========================================
# Other dependencies for GitHub Codespaces
# ========================================
- ipython
- pip:
- sphinx-collections # TODO (flferretti): PR to conda-forge
================================================
FILE: examples/.gitattributes
================================================
# GitHub syntax highlighting
pixi.lock linguist-language=YAML
================================================
FILE: examples/.gitignore
================================================
# pixi environments
.pixi
================================================
FILE: examples/README.md
================================================
# JaxSim Examples
This folder contains Jupyter notebooks that demonstrate the practical usage of JaxSim.
## Featured examples
| Notebook | Google Colab | Description |
| :--- | :---: | :--- |
| [`jaxsim_as_multibody_dynamics_library`](./jaxsim_as_multibody_dynamics_library.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_multibody_dynamics] | An example demonstrating how to use JaxSim as a multibody dynamics library. |
| [`jaxsim_as_physics_engine.ipynb`](./jaxsim_as_physics_engine.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_physics_engine] | An example demonstrating how to simulate vectorized models in parallel. |
| [`jaxsim_as_physics_engine_advanced.ipynb`](./jaxsim_as_physics_engine_advanced.ipynb) | [![Open In Colab][colab_badge]][jaxsim_as_physics_engine_advanced] | An example showcasing advanced JaxSim usage, such as customizing the integrator, contact model, and more. |
| [`jaxsim_for_robot_controllers.ipynb`](./jaxsim_for_robot_controllers.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_closed_loop] | A basic example showing how to simulate a PD controller with gravity compensation for a 2-DOF cart-pole. |
[colab_badge]: https://colab.research.google.com/assets/colab-badge.svg
[ipynb_jaxsim_closed_loop]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_for_robot_controllers.ipynb
[ipynb_jaxsim_as_physics_engine]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_physics_engine.ipynb
[jaxsim_as_physics_engine_advanced]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_physics_engine_advanced.ipynb
[ipynb_jaxsim_as_multibody_dynamics]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_multibody_dynamics_library.ipynb
## How to run the examples
You can run the JaxSim examples with hardware acceleration in two ways.
### Option 1: Google Colab (recommended)
The easiest way is to use the provided Google Colab links to run the notebooks in a hosted environment
with no setup required.
### Option 2: Local execution with `pixi`
To run the examples locally, first install `pixi` following the [official documentation][pixi_installation]:
[pixi_installation]: https://pixi.sh/#installation
```bash
curl -fsSL https://pixi.sh/install.sh | bash
```
Then, from the repository's root directory, execute the example notebooks using:
```bash
pixi run examples
```
This command will automatically handle all necessary dependencies and run the examples in a self-contained environment.
================================================
FILE: examples/assets/build_cartpole_urdf.py
================================================
import os

# Lower ROD's logging verbosity unless a level was already chosen by the user.
# This is done before importing rod, presumably so that rod sees it when it
# initializes (NOTE(review): confirm where rod reads this variable).
os.environ.setdefault("ROD_LOGGING_LEVEL", "WARNING")

import numpy as np
import rod.kinematics.tree_transforms
from rod.builder import primitives


def _cartpole_sdf():
    """Assemble the in-memory SDF (version 1.10) description of the cart-pole.

    The model is made of three links: ``rail`` (cylinder), ``cart`` (box), and
    ``pole`` (cylinder). They are connected by a fixed joint from ``world`` to
    the rail, a prismatic joint ``linear`` between rail and cart, and a
    continuous joint ``pivot`` between cart and pole. The pose elements are
    switched to the URDF frame convention with explicit frames.

    Returns:
        The top-level ``rod.Sdf`` object wrapping the cart-pole model.
    """

    build_pose = primitives.PrimitiveBuilder.build_pose

    # ----------------
    # Model parameters
    # ----------------

    rail_height, rail_length, rail_radius, rail_mass = 1.2, 5.0, 0.005, 5.0
    cart_mass, cart_size = 1.0, (0.1, 0.2, 0.05)
    pole_mass, pole_length, pole_radius = 0.5, 1.0, 0.005

    # Half-stroke of the prismatic joint, keeping the whole cart on the rail.
    cart_travel = rail_length / 2 - cart_size[1] / 2

    # -------------
    # Link builders
    # -------------

    rail_builder = primitives.CylinderBuilder(
        name="rail", mass=rail_mass, radius=rail_radius, length=rail_length
    )
    cart_builder = primitives.BoxBuilder(
        name="cart", mass=cart_mass, x=cart_size[0], y=cart_size[1], z=cart_size[2]
    )
    pole_builder = primitives.CylinderBuilder(
        name="pole", mass=pole_mass, radius=pole_radius, length=pole_length
    )

    # ------
    # Joints
    # ------

    world_to_rail = rod.Joint(
        name="world_to_rail",
        type="fixed",
        parent="world",
        child=rail_builder.name,
        pose=build_pose(relative_to="world"),
    )

    linear = rod.Joint(
        name="linear",
        type="prismatic",
        parent=rail_builder.name,
        child=cart_builder.name,
        pose=build_pose(
            relative_to=rail_builder.name, pos=np.array([0, 0, rail_height])
        ),
        axis=rod.Axis(
            xyz=rod.Xyz(xyz=[0, 1, 0]),
            limit=rod.Limit(
                upper=cart_travel,
                lower=-cart_travel,
                effort=500.0,
                velocity=10.0,
            ),
        ),
    )

    pivot = rod.Joint(
        name="pivot",
        type="continuous",
        parent=cart_builder.name,
        child=pole_builder.name,
        pose=build_pose(relative_to=cart_builder.name),
        axis=rod.Axis(xyz=rod.Xyz(xyz=[1, 0, 0]), limit=rod.Limit()),
    )

    # -----
    # Links
    # -----

    # Rail elements are lifted to the rail height and rotated by pi/2 about x.
    rail_pose = build_pose(
        pos=np.array([0, 0, rail_height]), rpy=np.array([np.pi / 2, 0, 0])
    )
    rail = (
        rail_builder.build_link(
            name=rail_builder.name, pose=build_pose(relative_to=world_to_rail.name)
        )
        .add_inertial(pose=rail_pose)
        .add_visual(pose=rail_pose)
        .add_collision(pose=rail_pose)
        .build()
    )

    cart = (
        cart_builder.build_link(
            name=cart_builder.name, pose=build_pose(relative_to=linear.name)
        )
        .add_inertial()
        .add_visual()
        .add_collision()
        .build()
    )

    # Pole elements are shifted so that the pole starts at the pivot.
    pole_pose = build_pose(pos=np.array([0, 0, pole_length / 2]))
    pole = (
        pole_builder.build_link(
            name=pole_builder.name, pose=build_pose(relative_to=pivot.name)
        )
        .add_inertial(pose=pole_pose)
        .add_visual(pose=pole_pose)
        .add_collision(pose=pole_pose)
        .build()
    )

    # -----
    # Model
    # -----

    model = rod.Model(
        name="cartpole",
        canonical_link=rail.name,
        link=[rail, cart, pole],
        joint=[world_to_rail, linear, pivot],
    )

    # Bring the pose elements closer to what URDF expects.
    model.switch_frame_convention(
        frame_convention=rod.FrameConvention.Urdf, explicit_frames=True
    )

    # To inspect the SDF instead: sdf.serialize(pretty=True, validate=True)
    return rod.Sdf(version="1.10", model=model)


def main():
    """Print the cart-pole model serialized as a pretty-printed URDF string."""

    sdf = _cartpole_sdf()

    # The exporter is only needed for this final conversion step.
    import rod.urdf.exporter

    exporter = rod.urdf.exporter.UrdfExporter(pretty=True, indent=" ")
    print(exporter.to_urdf_string(sdf=sdf))


if __name__ == "__main__":
    main()
================================================
FILE: examples/assets/cartpole.urdf
================================================
<?xml version="1.0" encoding="utf-8"?>
<robot name="cartpole">
<link name="world"/>
<link name="rail">
<inertial>
<origin xyz="0.0 0.0 1.2" rpy="1.5707963267948963 0.0 0.0"/>
<mass value="5.0"/>
<inertia ixx="10.416697916666665" ixy="0.0" ixz="0.0" iyy="10.416697916666665" iyz="0.0" izz="6.25e-05"/>
</inertial>
<visual name="rail_visual">
<origin xyz="0.0 0.0 1.2" rpy="1.5707963267948963 0.0 0.0"/>
<geometry>
<cylinder radius="0.005" length="5.0"/>
</geometry>
</visual>
<collision name="rail_collision">
<origin xyz="0.0 0.0 1.2" rpy="1.5707963267948963 0.0 0.0"/>
<geometry>
<cylinder radius="0.005" length="5.0"/>
</geometry>
</collision>
</link>
<link name="cart">
<inertial>
<origin xyz="0.0 0.0 0.0" rpy="0.0 0.0 0.0"/>
<mass value="1.0"/>
<inertia ixx="0.0035416666666666674" ixy="0.0" ixz="0.0" iyy="0.0010416666666666669" iyz="0.0" izz="0.0041666666666666675"/>
</inertial>
<visual name="cart_visual">
<origin xyz="0.0 0.0 0.0" rpy="0.0 0.0 0.0"/>
<geometry>
<box size="0.1 0.2 0.05"/>
</geometry>
</visual>
<collision name="cart_collision">
<origin xyz="0.0 0.0 0.0" rpy="0.0 0.0 0.0"/>
<geometry>
<box size="0.1 0.2 0.05"/>
</geometry>
</collision>
</link>
<link name="pole">
<inertial>
<origin xyz="0.0 0.0 0.5" rpy="0.0 0.0 0.0"/>
<mass value="0.5"/>
<inertia ixx="0.04166979166666667" ixy="0.0" ixz="0.0" iyy="0.04166979166666667" iyz="0.0" izz="6.25e-06"/>
</inertial>
<visual name="pole_visual">
<origin xyz="0.0 0.0 0.5" rpy="0.0 0.0 0.0"/>
<geometry>
<cylinder radius="0.005" length="1.0"/>
</geometry>
</visual>
<collision name="pole_collision">
<origin xyz="0.0 0.0 0.5" rpy="0.0 0.0 0.0"/>
<geometry>
<cylinder radius="0.005" length="1.0"/>
</geometry>
</collision>
</link>
<link name="cart_frame"/>
<link name="rail_frame"/>
<joint name="cart_frame_joint" type="fixed">
<parent link="cart" />
<child link="cart_frame" />
<origin xyz="0.0 0.0 0.0" rpy="0.0 0.0 0.0" />
</joint>
<joint name="rail_frame_joint" type="fixed">
<parent link="rail" />
<child link="rail_frame" />
<origin xyz="0.0 0.0 1.2" rpy="0.0 0.0 0.0" />
</joint>
<joint name="world_to_rail" type="fixed">
<origin xyz="0.0 0.0 0.0" rpy="0.0 0.0 0.0"/>
<parent link="world"/>
<child link="rail"/>
</joint>
<joint name="linear" type="prismatic">
<origin xyz="0.0 0.0 1.2" rpy="0.0 0.0 0.0"/>
<parent link="rail"/>
<child link="cart"/>
<axis xyz="0 1 0"/>
<limit effort="500.0" velocity="10.0" lower="-2.4" upper="2.4"/>
</joint>
<joint name="pivot" type="continuous">
<origin xyz="0.0 0.0 0.0" rpy="0.0 0.0 0.0"/>
<parent link="cart"/>
<child link="pole"/>
<axis xyz="1 0 0"/>
<limit effort="3.4028235e+38" velocity="3.4028235e+38"/>
</joint>
</robot>
================================================
FILE: examples/jaxsim_as_multibody_dynamics_library.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "DpLq0-lltwZ1"
},
"source": [
"# `JaxSim` as a multibody dynamics library\n",
"\n",
"JaxSim was initially developed as a **hardware-accelerated physics engine**. Over time, it has evolved, adding new features to become a comprehensive **JAX-based multibody dynamics library**.\n",
"\n",
"In this notebook, you'll explore the main APIs for loading robot models and computing key quantities for applications such as control, planning, and more.\n",
"\n",
"A key advantage of JaxSim is its ability to create fully differentiable closed-loop systems, enabling end-to-end optimization. Combined with the flexibility to parameterize model kinematics and dynamics, JaxSim can serve as an excellent playground for robot learning applications.\n",
"\n",
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_multibody_dynamics_library.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rcEwprINtwZ3"
},
"source": [
"## Prepare environment\n",
"\n",
"First, we need to install the necessary packages and import their resources."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "u4xL7dbBtwZ3",
"outputId": "1a088e28-e005-4910-928c-cb641e589ab5"
},
"outputs": [],
"source": [
"# @title Imports and setup\n",
"from IPython.display import clear_output\n",
"import sys\n",
"\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"\n",
"# Install JAX, sdformat, and other notebook dependencies.\n",
"if IS_COLAB:\n",
" !{sys.executable} -m pip install --pre -qU jaxsim\n",
" !{sys.executable} -m pip install robot_descriptions>=1.16.0\n",
" !apt install -qq lsb-release wget gnupg\n",
" !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n",
" !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n",
" !apt -qq update\n",
" !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n",
"\n",
" clear_output()\n",
"\n",
"import os\n",
"import pathlib\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jaxsim.api as js\n",
"import jaxsim.math\n",
"from jaxsim import logging\n",
"from jaxsim import VelRepr\n",
"\n",
"logging.set_logging_level(logging.LoggingLevel.WARNING)\n",
"print(f\"Running on {jax.devices()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fN8Xg4QgtwZ4"
},
"source": [
"## Robot model\n",
"\n",
"JaxSim allows loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files.\n",
"\n",
"In this example, we will use the [ErgoCub][ergocub] humanoid robot model. If you have a URDF/SDF file for your robot that is compatible with [`gazebosim/sdformat`][sdformat_github][1], it should work out-of-the-box with JaxSim.\n",
"\n",
"[sdformat]: http://sdformat.org/\n",
"[urdf]: http://wiki.ros.org/urdf/\n",
"[ergocub]: https://ergocub.eu/\n",
"[sdformat_github]: https://github.com/gazebosim/sdformat\n",
"\n",
"---\n",
"\n",
"[1]: JaxSim validates robot descriptions using the command `gz sdf -p /path/to/file.urdf`. Ensure this command runs successfully on your file.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rB0BFxyPtwZ5"
},
"outputs": [],
"source": [
"# @title Fetch the URDF file\n",
"\n",
"try:\n",
" os.environ[\"ROBOT_DESCRIPTION_COMMIT\"] = \"v0.7.7\"\n",
"\n",
" import robot_descriptions.ergocub_description\n",
"\n",
"finally:\n",
" _ = os.environ.pop(\"ROBOT_DESCRIPTION_COMMIT\", None)\n",
"\n",
"model_description_path = pathlib.Path(\n",
" robot_descriptions.ergocub_description.URDF_PATH.replace(\n",
" \"ergoCubSN002\", \"ergoCubSN001\"\n",
" )\n",
")\n",
"\n",
"clear_output()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jeTUZic8twZ5"
},
"source": [
"### Create the model and its data\n",
"\n",
"The dynamics of a generic floating-base model are governed by the following equations of motion:\n",
"\n",
"$$\n",
"M(\\mathbf{q}) \\dot{\\boldsymbol{\\nu}} + \\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) = B \\boldsymbol{\\tau} + \\sum_{L_i \\in \\mathcal{L}} J_{W,L_i}^\\top(\\mathbf{q}) \\: \\mathbf{f}_i\n",
".\n",
"$$\n",
"\n",
"Here, the system state is represented by:\n",
"\n",
"- $\\mathbf{q} = ({}^W \\mathbf{H}_B, \\, \\mathbf{s}) \\in \\text{SE}(3) \\times \\mathbb{R}^n$ is the generalized position.\n",
"- $\\boldsymbol{\\nu} = (\\boldsymbol{v}_{W,B}, \\, \\boldsymbol{\\omega}_{W,B}, \\, \\dot{\\mathbf{s}}) \\in \\mathbb{R}^{6+n}$ is the generalized velocity.\n",
"\n",
"The inputs to the system are:\n",
"\n",
"- $\\boldsymbol{\\tau} \\in \\mathbb{R}^n$ are the joint torques.\n",
"- $\\mathbf{f}_i \\in \\mathbb{R}^6$ is the 6D force applied to the link $L_i$.\n",
"\n",
"JaxSim exposes functional APIs to operate over the following two main data structures:\n",
"\n",
"- **`JaxSimModel`** stores all the constant information parsed from the model description.\n",
"- **`JaxSimModelData`** holds the state of the model.\n",
"\n",
"Additionally, JaxSim includes a utility class, **`JaxSimModelReferences`**, for managing and manipulating system inputs.\n",
"\n",
"---\n",
"\n",
"This notebook uses the notation summarized in the following report. Please refer to this document if you have any questions or if something is unclear.\n",
"\n",
"> Traversaro and Saccon, **Multibody dynamics notation**, 2019, [URL](https://pure.tue.nl/ws/portalfiles/portal/139293126/A_Multibody_Dynamics_Notation_Revision_2_.pdf)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WYgBAxU0twZ6"
},
"outputs": [],
"source": [
"# Create the model from the model description.\n",
"# JaxSim removes all fixed joints by lumping together their parent and child links.\n",
"full_model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_description_path\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DdaETmDStwZ6"
},
"source": [
"It is often useful to work with only a subset of joints, referred to as the _considered joints_. JaxSim allows reducing a model to these joints, which simplifies the computation of the rigid-body dynamics quantities.\n",
"\n",
"By default, the positions of the removed joints are considered to be zero. If this is not the case, the `reduce` function accepts a dictionary `dict[str, float]` to specify custom joint positions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QuhG7Zv5twZ7"
},
"outputs": [],
"source": [
"model = js.model.reduce(\n",
" model=full_model,\n",
" considered_joints=tuple(\n",
" j\n",
" for j in full_model.joint_names()\n",
" # Remove sensor joints.\n",
" if \"camera\" not in j\n",
" # Remove head and hands.\n",
" and \"neck\" not in j\n",
" and \"wrist\" not in j\n",
" and \"thumb\" not in j\n",
" and \"index\" not in j\n",
" and \"middle\" not in j\n",
" and \"ring\" not in j\n",
" and \"pinkie\" not in j\n",
" # Remove upper body.\n",
" and \"torso\" not in j and \"elbow\" not in j and \"shoulder\" not in j\n",
" ),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RLvAit_i2ZiA",
"outputId": "ea3954af-b9b9-46ac-d9cb-20b99b1eac94"
},
"outputs": [],
"source": [
"# Print model quantities.\n",
"print(f\"Model name: {model.name()}\")\n",
"print(f\"Number of links: {model.number_of_links()}\")\n",
"print(f\"Number of joints: {model.number_of_joints()}\")\n",
"\n",
"print()\n",
"print(f\"Links:\\n{model.link_names()}\")\n",
"\n",
"print()\n",
"print(f\"Joints:\\n{model.joint_names()}\")\n",
"\n",
"print()\n",
"print(f\"Frames:\\n{model.frame_names()}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Xp8V5on5twZ8",
"outputId": "cc1564db-ae91-4dba-92c9-b8b87bd65f10"
},
"outputs": [],
"source": [
"# Create a random data object from the reduced model.\n",
"data = js.data.random_model_data(model=model)\n",
"\n",
"# Print the default state.\n",
"W_H_B, s = data.generalized_position\n",
"ν = data.generalized_velocity\n",
"\n",
"print(f\"W_H_B: shape={W_H_B.shape}\\n{W_H_B}\\n\")\n",
"print(f\"s: shape={s.shape}\\n{s}\\n\")\n",
"print(f\"ν: shape={ν.shape}\\n{ν}\\n\") # noqa: RUF001"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XLx3sv9VtwZ9",
"outputId": "28f5f070-e37e-464e-d84e-2944cfdc28dc"
},
"outputs": [],
"source": [
"# Create a random link forces matrix.\n",
"link_forces = jax.random.uniform(\n",
" minval=-10.0,\n",
" maxval=10.0,\n",
" shape=(model.number_of_links(), 6),\n",
" key=jax.random.PRNGKey(0),\n",
")\n",
"\n",
"# Create a random joint force references vector.\n",
"# Note that these are called 'references' because the actual joint forces that\n",
"# are actuated might differ due to effects like joint friction.\n",
"joint_force_references = jax.random.uniform(\n",
" minval=-10.0, maxval=10.0, shape=(model.dofs(),), key=jax.random.PRNGKey(0)\n",
")\n",
"\n",
"# Create the references object.\n",
"references = js.references.JaxSimModelReferences.build(\n",
" model=model,\n",
" data=data,\n",
" link_forces=link_forces,\n",
" joint_force_references=joint_force_references,\n",
")\n",
"\n",
"print(f\"link_forces: shape={references.link_forces(model=model, data=data).shape}\")\n",
"print(f\"joint_force_references: shape={references.joint_force_references(model=model).shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AaG817vP4LfT"
},
"source": [
"## Robot Kinematics\n",
"\n",
"JaxSim offers functional APIs for computing kinematic quantities:\n",
"\n",
"- **`jaxsim.api.model`**: vectorized functions operating on the whole model.\n",
"- **`jaxsim.api.link`**: functions operating on individual links.\n",
"- **`jaxsim.api.frame`**: functions operating on individual frames. \n",
"\n",
"Due to JAX limitations on vectorizable data types, many APIs operate on indices instead of names. Since using indices can be error-prone, JaxSim provides conversion functions for both links:\n",
"\n",
"- **jaxsim.api.link.names_to_idxs()**\n",
"- **jaxsim.api.link.idxs_to_names()**\n",
"\n",
"and frames: \n",
"\n",
"- **jaxsim.api.frame.names_to_idxs()**\n",
"- **jaxsim.api.frame.idxs_to_names()**\n",
"\n",
"We recommend using names whenever possible to avoid hard-to-trace errors.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QxImwfZz7pz-"
},
"outputs": [],
"source": [
"# Find the index of a link.\n",
"link_name = \"l_ankle_2\"\n",
"link_index = js.link.name_to_idx(model=model, link_name=link_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "C22Iqu2i4G-I",
"outputId": "94376151-177d-410f-f375-b7b8bd080992"
},
"outputs": [],
"source": [
"# @title Link Pose\n",
"\n",
"# Compute its pose w.r.t. the world frame through forward kinematics.\n",
"W_H_L = js.link.transform(model=model, data=data, link_index=link_index)\n",
"\n",
"print(f\"Transform of '{link_name}': shape={W_H_L.shape}\\n{W_H_L}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DnSpE_f97RkX",
"outputId": "a3f6b535-4ae5-49f4-8921-7fe4dda5debb"
},
"outputs": [],
"source": [
"# @title Link 6D Velocity\n",
"\n",
"# JaxSim lets you select the so-called representation of the link velocity.\n",
"L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body)\n",
"LW_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Mixed)\n",
"W_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial)\n",
"\n",
"print(f\"Body-fixed velocity L_v_WL={L_v_WL}\")\n",
"print(f\"Mixed velocity: LW_v_WL={LW_v_WL}\")\n",
"print(f\"Inertial-fixed velocity: W_v_WL={W_v_WL}\")\n",
"\n",
"# These can also be computed through the link free-floating Jacobian.\n",
"# This type of Jacobian has an input velocity representation that corresponds to\n",
"# the velocity representation of ν, and an output velocity representation that\n",
"# corresponds to the velocity representation of the desired 6D velocity.\n",
"\n",
"# You can use the following context manager to easily switch between representations.\n",
"with data.switch_velocity_representation(VelRepr.Body):\n",
"\n",
" # Body-fixed generalized velocity.\n",
" B_ν = data.generalized_velocity\n",
"\n",
" # Free-floating Jacobian accepting a body-fixed generalized velocity and\n",
" # returning an inertial-fixed link velocity.\n",
" W_J_WL_B = js.link.jacobian(\n",
" model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial\n",
" )\n",
"\n",
"# Now the following relation should hold.\n",
"assert jnp.allclose(W_v_WL, W_J_WL_B @ B_ν)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SSoziCShtwZ9"
},
"outputs": [],
"source": [
"# Find the index of a frame.\n",
"frame_name = \"l_foot_front\"\n",
"frame_index = js.frame.name_to_idx(model=model, frame_name=frame_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fVp_xP_1twZ9",
"outputId": "cfaa0569-d768-4708-c98c-a5867c056d04"
},
"outputs": [],
"source": [
"# @title Frame Pose\n",
"\n",
"# Compute its pose w.r.t. the world frame through forward kinematics.\n",
"W_H_F = js.frame.transform(model=model, data=data, frame_index=frame_index)\n",
"\n",
"print(f\"Transform of '{frame_name}': shape={W_H_F.shape}\\n{W_H_F}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QqaqxneEFYiW"
},
"outputs": [],
"source": [
"# @title Frame 6D Velocity\n",
"\n",
"# JaxSim lets you select the so-called representation of the frame velocity.\n",
"F_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Body)\n",
"FW_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Mixed)\n",
"W_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial)\n",
"\n",
"print(f\"Body-fixed velocity F_v_WF={F_v_WF}\")\n",
"print(f\"Mixed velocity: FW_v_WF={FW_v_WF}\")\n",
"print(f\"Inertial-fixed velocity: W_v_WF={W_v_WF}\")\n",
"\n",
"# These can also be computed through the frame free-floating Jacobian.\n",
"# This type of Jacobian has an input velocity representation that corresponds to\n",
"# the velocity representation of ν, and an output velocity representation that\n",
"# corresponds to the velocity representation of the desired 6D velocity.\n",
"\n",
"# You can use the following context manager to easily switch between representations.\n",
"with data.switch_velocity_representation(VelRepr.Body):\n",
"\n",
" # Body-fixed generalized velocity.\n",
" B_ν = data.generalized_velocity\n",
"\n",
" # Free-floating Jacobian accepting a body-fixed generalized velocity and\n",
" # returning an inertial-fixed frame velocity.\n",
" W_J_WF_B = js.frame.jacobian(\n",
" model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial\n",
" )\n",
"\n",
"# Now the following relation should hold.\n",
"assert jnp.allclose(W_v_WF, W_J_WF_B @ B_ν)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "d_vp6D74GoVZ",
"outputId": "798b9283-792e-4339-b56c-df2595fac974"
},
"source": [
"## Robot Dynamics\n",
"\n",
"JaxSim provides all the quantities involved in the equations of motion, restated here:\n",
"\n",
"$$\n",
"M(\\mathbf{q}) \\dot{\\boldsymbol{\\nu}} + \\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) = B \\boldsymbol{\\tau} + \\sum_{L_i \\in \\mathcal{L}} J_{W,L_i}^\\top(\\mathbf{q}) \\: \\mathbf{f}_i\n",
".\n",
"$$\n",
"\n",
"Specifically, it can compute:\n",
"\n",
"- $M(\\mathbf{q}) \\in \\mathbb{R}^{(6+n)\\times(6+n)}$: the mass matrix.\n",
"- $\\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) \\in \\mathbb{R}^{6+n}$: the vector of bias forces.\n",
"- $B \\in \\mathbb{R}^{(6+n) \\times n}$: the joint selector matrix.\n",
"- $J_{W,L} \\in \\mathbb{R}^{6 \\times (6+n)}$: the Jacobian of link $L$.\n",
"\n",
"Often, for convenience, link Jacobians are stacked together. Since JaxSim efficiently computes the Jacobians for all links, using the stacked version is recommended when needed:\n",
"\n",
"$$\n",
"M(\\mathbf{q}) \\dot{\\boldsymbol{\\nu}} + \\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) = B \\boldsymbol{\\tau} + J_{W,\\mathcal{L}}^\\top(\\mathbf{q}) \\: \\mathbf{f}_\\mathcal{L}\n",
".\n",
"$$\n",
"\n",
"Furthermore, there are applications that require unpacking the vector of bias forces as follows:\n",
"\n",
"$$\n",
"\\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) = C(\\mathbf{q}, \\boldsymbol{\\nu}) \\boldsymbol{\\nu} + \\mathbf{g}(\\mathbf{q})\n",
",\n",
"$$\n",
"\n",
"where:\n",
"\n",
"- $\\mathbf{g}(\\mathbf{q}) \\in \\mathbb{R}^{6+n}$: the vector of gravity forces.\n",
"- $C(\\mathbf{q}, \\boldsymbol{\\nu}) \\in \\mathbb{R}^{(6+n)\\times(6+n)}$: the Coriolis matrix.\n",
"\n",
"Below we report the functions to compute all these quantities. Note that all quantities depend on the active velocity representation of `data`. As was done for the link velocity, the representation of all the computed quantities can be changed by operating within the corresponding context manager. Here we consider the default representation of `data`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oOKJOVfsH4Ki"
},
"outputs": [],
"source": [
"print(\"Velocity representation of data:\", data.velocity_representation, \"\\n\")\n",
"\n",
"# Compute the mass matrix.\n",
"M = js.model.free_floating_mass_matrix(model=model, data=data)\n",
"print(f\"M: shape={M.shape}\")\n",
"\n",
"# Compute the vector of bias forces.\n",
"h = js.model.free_floating_bias_forces(model=model, data=data)\n",
"print(f\"h: shape={h.shape}\")\n",
"\n",
"# Compute the vector of gravity forces.\n",
"g = js.model.free_floating_gravity_forces(model=model, data=data)\n",
"print(f\"g: shape={g.shape}\")\n",
"\n",
"# Compute the Coriolis matrix.\n",
"C = js.model.free_floating_coriolis_matrix(model=model, data=data)\n",
"print(f\"C: shape={C.shape}\")\n",
"\n",
"# Create the joint selector matrix.\n",
"B = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T\n",
"print(f\"B: shape={B.shape}\")\n",
"\n",
"# Compute the stacked tensor of link Jacobians.\n",
"J = js.model.generalized_free_floating_jacobian(model=model, data=data)\n",
"print(f\"J: shape={J.shape}\")\n",
"\n",
"# Extract the joint forces from the references object.\n",
"τ = references.joint_force_references(model=model)\n",
"print(f\"τ: shape={τ.shape}\")\n",
"\n",
"# Extract the link forces from the references object.\n",
"f_L = references.link_forces(model=model, data=data)\n",
"print(f\"f_L: shape={f_L.shape}\")\n",
"\n",
"# The following relation should hold.\n",
"assert jnp.allclose(h, C @ ν + g)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "FlNo8dNWKKtu",
"outputId": "313e939b-f88f-4407-c9ee-b5b3b7443061"
},
"source": [
"### Forward Dynamics\n",
"\n",
"$$\n",
"\\dot{\\boldsymbol{\\nu}} = \\text{FD}(\\mathbf{q}, \\boldsymbol{\\nu}, \\boldsymbol{\\tau}, \\mathbf{f}_{\\mathcal{L}})\n",
"$$\n",
"\n",
"JaxSim provides two alternative methods to compute the forward dynamics:\n",
"\n",
"1. Operate on the quantities of the equations of motion.\n",
"2. Call the recursive Articulated Body Algorithm (ABA).\n",
"\n",
"The physics engine provided by JaxSim exploits the efficient calculation of the forward dynamics with ABA for simulating the trajectories of the system dynamics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LXARuRu1Ly1K"
},
"outputs": [],
"source": [
"ν̇_eom = jnp.linalg.pinv(M) @ (B @ τ - h + jnp.einsum(\"l6g,l6->g\", J, f_L))\n",
"\n",
"v̇_WB, s̈ = js.model.forward_dynamics_aba(\n",
" model=model, data=data, link_forces=f_L, joint_forces=joint_force_references\n",
")\n",
"\n",
"ν̇_aba = jnp.hstack([v̇_WB, s̈])\n",
"print(f\"ν̇: shape={ν̇_aba.shape}\") # noqa: RUF001\n",
"\n",
"# The following relation should hold.\n",
"assert jnp.allclose(ν̇_eom, ν̇_aba)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "g5GOYXDnLySU",
"outputId": "ad4ce77d-d06f-473a-9c32-040680d76aa5"
},
"source": [
"### Inverse Dynamics\n",
"\n",
"$$\n",
"(\\boldsymbol{\\tau}, \\, \\mathbf{f}_B) = \\text{ID}(\\mathbf{q}, \\boldsymbol{\\nu}, \\dot{\\boldsymbol{\\nu}}, \\mathbf{f}_{\\mathcal{L}})\n",
"$$\n",
"\n",
"JaxSim offers two methods to compute inverse dynamics:\n",
"\n",
"- Directly use the quantities from the equations of motion.\n",
"- Use the Recursive Newton-Euler Algorithm (RNEA).\n",
"\n",
"Unlike many other implementations, JaxSim's RNEA for floating-base systems is the true inverse of $\\text{FD}$. It also computes the 6D force applied to the base link that generates the base acceleration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UTae5MjhaP2H"
},
"outputs": [],
"source": [
"f_B, τ_rnea = js.model.inverse_dynamics(\n",
" model=model,\n",
" data=data,\n",
" base_acceleration=v̇_WB,\n",
" joint_accelerations=s̈,\n",
" # To check that f_B works, let's remove the force applied\n",
" # to the base link from the link forces.\n",
" link_forces=f_L.at[0].set(jnp.zeros(6))\n",
")\n",
"\n",
"print(f\"f_B: shape={f_B.shape}\")\n",
"print(f\"τ_rnea: shape={τ_rnea.shape}\")\n",
"\n",
"# The following relations should hold.\n",
"assert jnp.allclose(τ_rnea, τ)\n",
"assert jnp.allclose(f_B, link_forces[0])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gYZ1jK1Neg1H",
"outputId": "0de79770-1e18-4027-bb47-5713bc1b4a72"
},
"source": [
"### Centroidal Dynamics\n",
"\n",
"Centroidal dynamics is a useful simplification often employed in planning and control applications. It represents the dynamics projected onto a mixed frame associated with the center of mass (CoM):\n",
"\n",
"$$\n",
"G = G[W] = ({}^W \\mathbf{p}_{\\text{CoM}}, [W])\n",
".\n",
"$$\n",
"\n",
"The governing equations for centroidal dynamics take into account the 6D centroidal momentum:\n",
"\n",
"$$\n",
"{}_G \\mathbf{h} =\n",
"\\begin{bmatrix}\n",
"{}_G \\mathbf{h}^l \\\\ {}_G \\mathbf{h}^\\omega\n",
"\\end{bmatrix} =\n",
"\\begin{bmatrix}\n",
"m \\, {}^W \\dot{\\mathbf{p}}_\\text{CoM} \\\\ {}_G \\mathbf{h}^\\omega\n",
"\\end{bmatrix}\n",
"\\in \\mathbb{R}^6\n",
".\n",
"$$\n",
"\n",
"The equations of centroidal dynamics can be expressed as:\n",
"\n",
"$$\n",
"{}_G \\dot{\\mathbf{h}} =\n",
"m \\,\n",
"\\begin{bmatrix}\n",
"{}^W \\mathbf{g} \\\\ \\mathbf{0}_3\n",
"\\end{bmatrix} +\n",
"\\sum_{C_i \\in \\mathcal{C}} {}_G \\mathbf{X}^{C_i} \\, {}_{C_i} \\mathbf{f}_i\n",
".\n",
"$$\n",
"\n",
"While centroidal dynamics can function independently by considering the total mass $m \\in \\mathbb{R}$ of the robot and the transformations for 6D contact forces ${}_G \\mathbf{X}^{C_i}$ corresponding to the pose ${}^G \\mathbf{H}_{C_i} \\in \\text{SE}(3)$ of the contact frames, advanced kino-dynamic methods may require a relationship between full kinematics and centroidal dynamics. This is typically achieved through the _Centroidal Momentum Matrix_ (also known as the _centroidal momentum Jacobian_):\n",
"\n",
"$$\n",
"{}_G \\mathbf{h} = J_\\text{CMM}(\\mathbf{q}) \\, \\boldsymbol{\\nu}\n",
".\n",
"$$\n",
"\n",
"JaxSim offers APIs to compute all these quantities (and many more) in the `jaxsim.api.com` package."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rrSfxp8lh9YZ"
},
"outputs": [],
"source": [
"# Number of contact points.\n",
"n_cp = len(model.kin_dyn_parameters.contact_parameters.body)\n",
"print(\"Number of contact points:\", n_cp, \"\\n\")\n",
"\n",
"# Compute the centroidal momentum.\n",
"J_CMM = js.com.centroidal_momentum_jacobian(model=model, data=data)\n",
"G_h = J_CMM @ ν\n",
"print(f\"G_h: shape={G_h.shape}\")\n",
"print(f\"J_CMM: shape={J_CMM.shape}\")\n",
"\n",
"# The following relation should hold.\n",
"assert jnp.allclose(G_h, js.com.centroidal_momentum(model=model, data=data))\n",
"\n",
"# If we consider all contact points of the model as active\n",
"# (discouraged, since there might be too many), the 6D transforms of\n",
"# collidable points can be computed as follows:\n",
"W_H_C = js.contact.transforms(model=model, data=data)\n",
"\n",
"# Compute the pose of the G frame.\n",
"W_p_CoM = js.com.com_position(model=model, data=data)\n",
"G_H_W = jaxsim.math.Transform.inverse(jnp.eye(4).at[0:3, 3].set(W_p_CoM))\n",
"\n",
"# Convert from SE(3) to the transforms for 6D forces.\n",
"G_Xf_C = jax.vmap(\n",
" lambda W_H_Ci: jaxsim.math.Adjoint.from_transform(\n",
" transform=G_H_W @ W_H_Ci, inverse=True\n",
" )\n",
")(W_H_C)\n",
"print(f\"G_Xf_C: shape={G_Xf_C.shape}\")\n",
"\n",
"# Let's create random 3D linear forces applied to the contact points.\n",
"C_fl = jax.random.uniform(\n",
" minval=-10.0,\n",
" maxval=10.0,\n",
" shape=(n_cp, 3),\n",
" key=jax.random.PRNGKey(0),\n",
")\n",
"\n",
"# Compute the total mass of the robot (gravity is read from model.gravity below).\n",
"m = js.model.total_mass(model=model)\n",
"\n",
"# The centroidal dynamics can be computed as follows.\n",
"G_ḣ = 0\n",
"G_ḣ += m * jnp.hstack([0, 0, model.gravity, 0, 0, 0])\n",
"G_ḣ += jnp.einsum(\"c66,c6->6\", G_Xf_C, jnp.hstack([C_fl, jnp.zeros_like(C_fl)]))\n",
"print(f\"G_ḣ: shape={G_ḣ.shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ot6HePB_twaE",
"outputId": "02a6abae-257e-45ee-e9de-6a607cdbeb9a"
},
"source": [
"## Contact Frames\n",
"\n",
"Many control and planning applications require projecting the floating-base dynamics into the contact space or computing quantities related to active contact points, such as enforcing holonomic constraints.\n",
"\n",
"The underlying theory for these applications becomes clearer in a mixed representation. Specifically, the position, linear velocity, and linear acceleration of contact points in their corresponding mixed frame align with the numerical derivatives of their coordinate vectors.\n",
"\n",
"Key methodologies in this area may involve the Delassus matrix:\n",
"\n",
"$$\n",
"\\Psi(\\mathbf{q}) = J_{W,C}(\\mathbf{q}) \\, M(\\mathbf{q})^{-1} \\, J_{W,C}^\\top(\\mathbf{q})\n",
"$$\n",
"\n",
"or the linear acceleration of a contact point:\n",
"\n",
"$$\n",
"{}^W \\ddot{\\mathbf{p}}_C = \\frac{\\text{d} (J^l_{W,C} \\boldsymbol{\\nu})}{\\text{d}t}\n",
"= \\dot{J}^l_{W,C} \\boldsymbol{\\nu} + J^l_{W,C} \\dot{\\boldsymbol{\\nu}}\n",
".\n",
"$$\n",
"\n",
"JaxSim offers APIs to compute all these quantities (and many more) in the `jaxsim.api.contact` package."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LITRC3STliKR"
},
"outputs": [],
"source": [
"with (\n",
" data.switch_velocity_representation(VelRepr.Mixed),\n",
" references.switch_velocity_representation(VelRepr.Mixed),\n",
"):\n",
"\n",
" # Compute the mixed generalized velocity.\n",
" BW_ν = data.generalized_velocity\n",
"\n",
" # Compute the mixed generalized acceleration.\n",
" BW_ν̇ = jnp.hstack(\n",
" js.model.forward_dynamics(\n",
" model=model,\n",
" data=data,\n",
" link_forces=references.link_forces(model=model, data=data),\n",
" joint_forces=references.joint_force_references(model=model),\n",
" )\n",
" )\n",
"\n",
" # Compute the mass matrix in mixed representation.\n",
" BW_M = js.model.free_floating_mass_matrix(model=model, data=data)\n",
"\n",
" # Compute the contact Jacobian and its derivative.\n",
" Jl_WC = js.contact.jacobian(model=model, data=data)[:, 0:3, :]\n",
" J̇l_WC = js.contact.jacobian_derivative(model=model, data=data)[:, 0:3, :]\n",
"\n",
"# Compute the Delassus matrix.\n",
"Ψ = jnp.vstack(Jl_WC) @ jnp.linalg.lstsq(BW_M, jnp.vstack(Jl_WC).T)[0]\n",
"print(f\"Ψ: shape={Ψ.shape}\")\n",
"\n",
"# Compute the transforms of the mixed frames implicitly associated\n",
"# to each collidable point.\n",
"W_H_C = js.contact.transforms(model=model, data=data)\n",
"print(f\"W_H_C: shape={W_H_C.shape}\")\n",
"\n",
"# Compute the linear velocity of the collidable points.\n",
"with data.switch_velocity_representation(VelRepr.Mixed):\n",
" W_ṗ_B = js.contact.collidable_point_velocities(model=model, data=data)[:, 0:3]\n",
" print(f\"W_ṗ_B: shape={W_ṗ_B.shape}\")\n",
"\n",
"# Compute the linear acceleration of the collidable points.\n",
"W_p̈_C = 0\n",
"W_p̈_C += jnp.einsum(\"c3g,g->c3\", J̇l_WC, BW_ν)\n",
"W_p̈_C += jnp.einsum(\"c3g,g->c3\", Jl_WC, BW_ν̇)\n",
"print(f\"W_p̈_C: shape={W_p̈_C.shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LITRC3STliKR"
},
"source": [
"## Conclusions\n",
"\n",
"This notebook provided an overview of the main APIs in JaxSim for its use as a multibody dynamics library. Here are a few key points to remember:\n",
"\n",
"- Explore all the modules in the `jaxsim.api` package to discover the full range of APIs available. Many more functionalities exist beyond what was covered in this notebook.\n",
"- All APIs follow a functional approach, consistent with the JAX programming style.\n",
"- This functional design allows for easy application of `jax.vmap` to execute functions in parallel on hardware accelerators.\n",
"- Since the entire multibody dynamics library is built with JAX, it natively supports `jax.grad`, `jax.jacfwd`, and `jax.jacrev` transformations, enabling automatic differentiation through complex logic without additional effort.\n",
"\n",
"Have fun!"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuClass": "premium",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "comodo_jaxsim",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: examples/jaxsim_as_physics_engine.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "H-WgcgGQaTG7"
},
"source": [
"# JaxSim as a hardware-accelerated parallel physics engine\n",
"\n",
"This notebook shows how to use the key APIs to load a robot model and simulate multiple trajectories simultaneously.\n",
"\n",
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_physics_engine.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SgOSnrSscEkt"
},
"source": [
"## Prepare the environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fdqvAqMDaTG9"
},
"outputs": [],
"source": [
"# @title Imports and setup\n",
"import os\n",
"import sys\n",
"from IPython.display import clear_output\n",
"\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"\n",
"# Install JAX and Gazebo\n",
"if IS_COLAB:\n",
" !{sys.executable} -m pip install --pre -qU jaxsim\n",
" !apt install -qq lsb-release wget gnupg\n",
" !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n",
" !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n",
" !apt -qq update\n",
" !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n",
"\n",
" clear_output()\n",
"\n",
"# Set environment variable to avoid GPU out of memory errors\n",
"%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n",
"\n",
"\n",
"# ================\n",
"# Notebook imports\n",
"# ================\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jaxsim.api as js\n",
"from jaxsim import logging\n",
"import pathlib\n",
"\n",
"logging.set_logging_level(logging.LoggingLevel.WARNING)\n",
"print(f\"Running on {jax.devices()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NqjuZKvOaTG_"
},
"source": [
"## Prepare the simulation\n",
"\n",
"JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. In this example, we will load the [ergoCub][ergocub] model URDF.\n",
"\n",
"[sdformat]: http://sdformat.org/\n",
"[urdf]: http://wiki.ros.org/urdf/\n",
"[ergocub]: https://ergocub.eu/\n",
"[rod]: https://github.com/gbionics/rod\n",
"\n",
"### Create the model and its data\n",
"To define a simulation, we need two main objects:\n",
"\n",
"- `model`: an object that defines the dynamics of the system.\n",
"- `data`: an object that contains the state of the system.\n",
"\n",
"\n",
"The `JaxSimModel` object contains the simulation time step, the integrator, and the contact model.\n",
"For more advanced usage, refer to the advanced example, which shows how to explicitly pass an integrator class and state to the `model` object and how to change the contact model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create the model "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "etQ577cFaTHA"
},
"outputs": [],
"source": [
"# Create the JaxSim model.\n",
"try:\n",
" os.environ[\"ROBOT_DESCRIPTION_COMMIT\"] = \"v0.7.7\"\n",
"\n",
" import robot_descriptions.ergocub_description\n",
"\n",
"finally:\n",
" _ = os.environ.pop(\"ROBOT_DESCRIPTION_COMMIT\", None)\n",
"\n",
"model_description_path = pathlib.Path(\n",
" robot_descriptions.ergocub_description.URDF_PATH.replace(\n",
" \"ergoCubSN002\", \"ergoCubSN001\"\n",
" )\n",
")\n",
"\n",
"clear_output()\n",
"\n",
"full_model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_description_path,\n",
" time_step=0.0001,\n",
" is_urdf=True\n",
")\n",
"\n",
"joints_list = tuple(('l_shoulder_pitch', 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow',\n",
" 'r_shoulder_pitch', 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow',\n",
" 'l_hip_pitch', 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',\n",
" 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll'))\n",
"\n",
"model = js.model.reduce(\n",
" model=full_model,\n",
" considered_joints=joints_list\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create the data object \n",
"\n",
"The data object is never modified in place. Any time you call a method that modifies the data, like `reset_base_position`, a new data object is returned with the updated attributes, while the original data object remains unchanged."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create the data of a single model.\n",
"data = js.data.JaxSimModelData.build(model=model, base_position=jnp.array([0.0, 0.0, 1.0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Simulation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a random JAX key.\n",
"\n",
"key = jax.random.PRNGKey(seed=0)\n",
"\n",
"# Initialize the simulated time.\n",
"T = jnp.arange(start=0, stop=0.3, step=model.time_step)\n",
"\n",
"# Simulate\n",
"for _t in T:\n",
" data = js.model.step(\n",
" model=model,\n",
" data=data,\n",
" link_forces=None,\n",
" joint_force_references=None,\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Vectorized simulation \n",
"\n",
"We will now vectorize the simulation on batched data using `jax.vmap`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# first we have to vmap the function\n",
"\n",
"import functools\n",
"from typing import Any\n",
"\n",
"\n",
"@jax.jit\n",
"def step_single(\n",
" model: js.model.JaxSimModel,\n",
" data: js.data.JaxSimModelData,\n",
") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
"\n",
" # Close step over static arguments.\n",
" return js.model.step(\n",
" model=model,\n",
" data=data,\n",
" link_forces=None,\n",
" joint_force_references=None,\n",
" )\n",
"\n",
"\n",
"@jax.jit\n",
"@functools.partial(jax.vmap, in_axes=(None, 0))\n",
"def step_parallel(\n",
" model: js.model.JaxSimModel,\n",
" data: js.data.JaxSimModelData,\n",
") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
"\n",
" return step_single(\n",
" model=model, data=data\n",
" )\n",
"\n",
"\n",
"# Then we have to create the vector of initial state\n",
"batch_size = 5\n",
"data_batch_t0 = jax.vmap(\n",
" lambda pos: js.data.JaxSimModelData.build(model=model, base_position=pos))(jnp.tile(jnp.array([0.0, 0.0, 1.0]), (batch_size, 1)))\n",
"\n",
"data = data_batch_t0\n",
"for _t in T:\n",
" data = step_parallel(model, data)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuClass": "premium",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "jaxsim",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: examples/jaxsim_as_physics_engine_advanced.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "H-WgcgGQaTG7"
},
"source": [
"# JaxSim as a hardware-accelerated parallel physics engine: advanced usage\n",
"\n",
"JaxSim is developed to optimize synthetic data generation by sampling trajectories using hardware accelerators such as GPUs and TPUs.\n",
"\n",
"In this notebook, you'll learn how to use the key APIs to load a simple robot model (a sphere) and simulate multiple trajectories in parallel on GPUs.\n",
"\n",
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_physics_engine_advanced.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SgOSnrSscEkt"
},
"source": [
"## Prepare the environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fdqvAqMDaTG9"
},
"outputs": [],
"source": [
"# @title Imports and setup\n",
"import sys\n",
"from IPython.display import clear_output\n",
"\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"\n",
"# Install JAX and Gazebo\n",
"if IS_COLAB:\n",
" !{sys.executable} -m pip install --pre -qU jaxsim[viz]\n",
" !apt install -qq lsb-release wget gnupg\n",
" !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n",
" !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n",
" !apt -qq update\n",
" !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n",
"\n",
" clear_output()\n",
"\n",
"# Set environment variable to avoid GPU out of memory errors\n",
"%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n",
"\n",
"# ================\n",
"# Notebook imports\n",
"# ================\n",
"\n",
"import os\n",
"\n",
"if sys.platform == 'darwin':\n",
" os.environ[\"MUJOCO_GL\"] = \"glfw\"\n",
"else:\n",
" os.environ[\"MUJOCO_GL\"] = \"egl\"\n",
"\n",
"import jax\n",
"\n",
"import jax.numpy as jnp\n",
"import jaxsim.api as js\n",
"import rod\n",
"from jaxsim import logging\n",
"from rod.builder.primitives import SphereBuilder\n",
"\n",
"logging.set_logging_level(logging.LoggingLevel.WARNING)\n",
"print(f\"Running on {jax.devices()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QtCCUhdpdGFH"
},
"source": [
"## Prepare the simulation\n",
"\n",
"JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. This is done using the [`gbionics/rod`][rod] library, which processes these formats.\n",
"\n",
"The `rod` library also allows creating in-memory models that can be serialized to SDF or URDF. We'll use this functionality to build a sphere model, which will later be used to create the JaxSim model.\n",
"\n",
"[sdformat]: http://sdformat.org/\n",
"[urdf]: http://wiki.ros.org/urdf/\n",
"[rod]: https://github.com/gbionics/rod"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "0emoMQhCaTG_"
},
"outputs": [],
"source": [
"# @title Create the model description of a sphere\n",
"\n",
"# Create a SDF model.\n",
"# The builder takes care to compute the right inertia tensor for you.\n",
"rod_sdf = rod.Sdf(\n",
" version=\"1.7\",\n",
" model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n",
" .build_model()\n",
" .add_link()\n",
" .add_inertial()\n",
" .add_visual()\n",
" .add_collision()\n",
" .build(),\n",
")\n",
"\n",
"# Rod allows to update the frames w.r.t. the poses are expressed.\n",
"rod_sdf.model.switch_frame_convention(\n",
" frame_convention=rod.FrameConvention.Urdf, explicit_frames=True\n",
")\n",
"\n",
"# Serialize the model to a SDF string.\n",
"model_sdf_string = rod_sdf.serialize(pretty=True)\n",
"print(model_sdf_string)\n",
"\n",
"# JaxSim currently only supports collisions between points attached to bodies\n",
"# and a ground surface modeled as a heightmap sampled from a smooth function.\n",
"# While this approach is universal as it applies to generic meshes, the number\n",
"# of considered points greatly affects the performance. Spheres, by default,\n",
"# are discretized with 250 points. It's too much for this simple example.\n",
"# This number can be decreased with the following environment variable.\n",
"os.environ[\"JAXSIM_COLLISION_SPHERE_POINTS\"] = \"50\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NqjuZKvOaTG_"
},
"source": [
"### Create the model and its data\n",
"\n",
"JaxSim offers a simple functional API to interact with the simulation in a memory-efficient way. Four main elements are used to define a simulation:\n",
"\n",
"- `model`: an object that defines the dynamics of the system.\n",
"- `data`: an object that contains the state of the system.\n",
"- `integrator` *(Optional)*: an object that defines the integration method.\n",
"- `integrator_metadata` *(Optional)*: an object that contains the state of the integrator.\n",
"\n",
"The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n",
"In this example, we will explicitly pass an integrator class to the `model` object and we will use the default `SoftContacts` contact model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "etQ577cFaTHA"
},
"outputs": [],
"source": [
"# Create the JaxSim model.\n",
"# This is shared among all the parallel instances.\n",
"model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_sdf_string,\n",
" time_step=0.001,\n",
")\n",
"\n",
"# Create the data of a single model.\n",
"# We will create a vectorized instance later.\n",
"data_single = js.data.JaxSimModelData.zero(model=model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o86Teq5piVGj"
},
"outputs": [],
"source": [
"# Initialize the simulated time.\n",
"T = jnp.arange(start=0, stop=1.0, step=model.time_step)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V6IeD2B3m4F0"
},
"source": [
"## Sample a batch of trajectories in parallel\n",
"\n",
"With the provided resources, you can step through an open-loop trajectory on a single model using `jaxsim.api.model.step`.\n",
"\n",
"In this notebook, we'll focus on running parallel steps. We'll use JAX's automatic vectorization to apply the step function to batched data.\n",
"\n",
"Note that these parallel simulations are independent — models don't interact, so there's no need to avoid initial collisions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vtEn0aIzr_2j"
},
"outputs": [],
"source": [
"# @title Generate batched initial data\n",
"\n",
"# Create a random JAX key.\n",
"key = jax.random.PRNGKey(seed=0)\n",
"\n",
"# Split subkeys for sampling random initial data.\n",
"batch_size = 9\n",
"row_length = int(jnp.sqrt(batch_size))\n",
"row_dist = 0.3 * row_length\n",
"key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n",
"\n",
"# Create the batched data by sampling the height from [0.3, 1.2] meters.\n",
"data_batch_t0 = jax.vmap(\n",
" lambda key: js.data.random_model_data(\n",
" model=model,\n",
" key=key,\n",
" base_pos_bounds=([0, 0, 0.3], [0, 0, 1.2]),\n",
" base_vel_lin_bounds=(0, 0),\n",
" base_vel_ang_bounds=(0, 0),\n",
" )\n",
")(jnp.vstack(subkeys))\n",
"\n",
"x, y = jnp.meshgrid(\n",
" jnp.linspace(-row_dist, row_dist, num=row_length),\n",
" jnp.linspace(-row_dist, row_dist, num=row_length),\n",
")\n",
"xy_coordinate = jnp.stack([x.flatten(), y.flatten()], axis=-1)\n",
"\n",
"# Reset the x and y position to a grid.\n",
"data_batch_t0 = data_batch_t0.replace(\n",
" model=model,\n",
" base_position=data_batch_t0.base_position.at[:, :2].set(xy_coordinate),\n",
")\n",
"\n",
"print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position[0:10])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0tQPfsl6uxHm"
},
"outputs": [],
"source": [
"# @title Create parallel step function\n",
"\n",
"import functools\n",
"from typing import Any\n",
"\n",
"\n",
"@jax.jit\n",
"def step_single(\n",
" model: js.model.JaxSimModel,\n",
" data: js.data.JaxSimModelData,\n",
") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
"\n",
" # Close step over static arguments.\n",
" return js.model.step(\n",
" model=model,\n",
" data=data,\n",
" link_forces=None,\n",
" joint_force_references=None,\n",
" )\n",
"\n",
"\n",
"@jax.jit\n",
"@functools.partial(jax.vmap, in_axes=(None, 0))\n",
"def step_parallel(\n",
" model: js.model.JaxSimModel,\n",
" data: js.data.JaxSimModelData,\n",
") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n",
"\n",
" return step_single(\n",
" model=model, data=data\n",
" )\n",
"\n",
"\n",
"# The first run will be slow since JAX needs to JIT-compile the functions.\n",
"_ = step_single(model, data_single)\n",
"_ = step_parallel(model, data_batch_t0)\n",
"\n",
"# Benchmark the execution of a single step.\n",
"print(\"\\nSingle simulation step:\")\n",
"%timeit step_single(model, data_single)\n",
"\n",
"# On hardware accelerators, there's a range of batch_size values where\n",
"# increasing the number of parallel instances doesn't affect computation time.\n",
"# This range depends on the GPU/TPU specifications.\n",
"print(f\"\\nParallel simulation steps (batch_size={batch_size} on {jax.devices()[0]}):\")\n",
"%timeit step_parallel(model, data_batch_t0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VNwzT2JQ1n15"
},
"outputs": [],
"source": [
"# @title Run parallel simulation\n",
"\n",
"data = data_batch_t0\n",
"data_trajectory_list = []\n",
"\n",
"for _ in T:\n",
"\n",
" data = step_parallel(model, data)\n",
" data_trajectory_list.append(data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y6n720Cr3G44"
},
"source": [
"## Visualize trajectory"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BLPODyKr3Lyg"
},
"outputs": [],
"source": [
"# Convert a list of PyTrees to a batched PyTree.\n",
"# This operation is called 'tree transpose' in JAX.\n",
"data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\n",
"\n",
"print(f\"W_p_B: shape={data_trajectory.base_position.shape}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-jxJXy5r3RMt"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"plt.plot(T, data_trajectory.base_position[:, :, 2])\n",
"plt.grid(True)\n",
"plt.xlabel(\"Time [s]\")\n",
"plt.ylabel(\"Height [m]\")\n",
"plt.title(\"Height trajectory of the sphere\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jaxsim.mujoco\n",
"\n",
"mjcf_string, assets = jaxsim.mujoco.ModelToMjcf.convert(\n",
" model.built_from,\n",
" cameras=jaxsim.mujoco.loaders.MujocoCamera.build_from_target_view(\n",
" camera_name=\"sphere_cam\",\n",
" lookat=[0, 0, 0.3],\n",
" distance=4,\n",
" azimuth=150,\n",
" elevation=-10,\n",
" ),\n",
")\n",
"\n",
"# Create a helper for each parallel instance.\n",
"mj_model_helpers = [\n",
" jaxsim.mujoco.MujocoModelHelper.build_from_xml(\n",
" mjcf_description=mjcf_string, assets=assets\n",
" )\n",
" for _ in range(batch_size)\n",
"]\n",
"\n",
"# Create the video recorder.\n",
"recorder = jaxsim.mujoco.MujocoVideoRecorder(\n",
" model=mj_model_helpers[0].model,\n",
" data=[helper.data for helper in mj_model_helpers],\n",
" fps=int(1 / model.time_step),\n",
" width=320 * 2,\n",
" height=240 * 2,\n",
")\n",
"\n",
"for data_t in data_trajectory_list:\n",
"\n",
" for helper, base_position, base_quaternion, joint_position in zip(\n",
" mj_model_helpers,\n",
" data_t.base_position,\n",
" data_t.base_orientation,\n",
" data_t.joint_positions,\n",
" strict=True,\n",
" ):\n",
" helper.set_base_position(position=base_position)\n",
" helper.set_base_orientation(orientation=base_quaternion)\n",
"\n",
" if model.dofs() > 0:\n",
" helper.set_joint_positions(\n",
" positions=joint_position, joint_names=model.joint_names()\n",
" )\n",
"\n",
" # Record a new video frame.\n",
" recorder.record_frame(camera_name=\"sphere_cam\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import mediapy as media\n",
"\n",
"media.show_video(recorder.frames, fps=recorder.fps)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuClass": "premium",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "jaxpypi",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: examples/jaxsim_for_robot_controllers.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "EhPy6FgiZH4d"
},
"source": [
"# JaxSim for developing closed-loop robot controllers\n",
"\n",
"Originally developed as a **hardware-accelerated physics engine**, JaxSim has expanded its capabilities to become a full-featured **JAX-based multibody dynamics library**.\n",
"\n",
"In this notebook, you'll explore how to combine these two core features. Specifically, you'll learn how to load a robot model and design a model-based controller for closed-loop simulations.\n",
"\n",
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_for_robot_controllers.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vsf1AlxdZH4f"
},
"outputs": [],
"source": [
"# @title Prepare the environment\n",
"from IPython.display import clear_output\n",
"import sys\n",
"\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"\n",
"# Install JAX, sdformat, and other notebook dependencies.\n",
"if IS_COLAB:\n",
" !{sys.executable} -m pip install --pre -qU jaxsim[viz]\n",
" !apt install -qq lsb-release wget gnupg\n",
" !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n",
" !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n",
" !apt -qq update\n",
" !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n",
"\n",
" clear_output()\n",
"\n",
"# ================\n",
"# Notebook imports\n",
"# ================\n",
"\n",
"import os\n",
"\n",
"if sys.platform == 'darwin':\n",
" os.environ[\"MUJOCO_GL\"] = \"glfw\"\n",
"else:\n",
" os.environ[\"MUJOCO_GL\"] = \"egl\"\n",
"\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jaxsim.mujoco\n",
"from jaxsim import logging\n",
"\n",
"logging.set_logging_level(logging.LoggingLevel.WARNING)\n",
"print(f\"Running on {jax.devices()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kN-b9nOsZH4g"
},
"source": [
"We will use a simple cartpole model for this example. The cartpole model is a 2D model with a cart that can move horizontally and a pole that can rotate around the cart. The state of the cartpole is given by the position of the cart, the angle of the pole, the velocity of the cart, and the angular velocity of the pole. The control input is the horizontal force applied to the cart."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5aLqrZDqR5LA"
},
"source": [
"## Prepare the simulation\n",
"\n",
"JaxSim supports loading robot models from both [SDF][sdformat] and [URDF][urdf] files, utilizing the [`gbionics/rod`][rod] library for processing these formats.\n",
"\n",
"The `rod` library can read URDF files and validate them internally using [`gazebosim/sdformat`][sdformat_github]. In this example, we'll load a cart-pole model, which will be used to create the JaxSim simulation model.\n",
"\n",
"[sdformat]: http://sdformat.org/\n",
"[urdf]: http://wiki.ros.org/urdf/\n",
"[rod]: https://github.com/gbionics/rod\n",
"[sdformat_github]: https://github.com/gazebosim/sdformat"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.path.abspath(\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PZM7hEvFZH4h"
},
"outputs": [],
"source": [
"# @title Load the URDF model\n",
"import pathlib\n",
"import urllib\n",
"\n",
"# Retrieve the file\n",
"url = \"https://raw.githubusercontent.com/gbionics/jaxsim/refs/heads/main/examples/assets/cartpole.urdf\"\n",
"model_path, _ = urllib.request.urlretrieve(url)\n",
"model_urdf_string = pathlib.Path(model_path).read_text()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M5XsKehvZH4j"
},
"outputs": [],
"source": [
"# @title Create the model and its data\n",
"\n",
"import jaxsim.api as js\n",
"\n",
"# Create the model from the model description.\n",
"model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_urdf_string,\n",
" time_step=0.010,\n",
")\n",
"\n",
"# Create the data storing the simulation state.\n",
"data_zero = js.data.JaxSimModelData.zero(model=model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jk9csR5ETgn1"
},
"outputs": [],
"source": [
"# @title Define simulation parameters\n",
"\n",
"# Initialize the simulated time.\n",
"T = jnp.arange(start=0, stop=5.0, step=model.time_step)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bo6Ke5nAWL-S"
},
"source": [
"## Prepare the MuJoCo renderer\n",
"\n",
"For visualization purposes, we use the passive viewer of the MuJoCo simulator. It can either open an interactive window when used locally or record a video when used in notebooks."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "j1_I2i5TZH4n"
},
"outputs": [],
"source": [
"# Create the MJCF resources from the URDF.\n",
"mjcf_string, assets = jaxsim.mujoco.UrdfToMjcf.convert(\n",
" urdf=model.built_from,\n",
" # Create the camera used by the recorder.\n",
" cameras=jaxsim.mujoco.loaders.MujocoCamera.build_from_target_view(\n",
" camera_name=\"cartpole_camera\",\n",
" lookat=js.link.com_position(\n",
" model=model,\n",
" data=data_zero,\n",
" link_index=js.link.name_to_idx(model=model, link_name=\"cart\"),\n",
" in_link_frame=False,\n",
" ),\n",
" distance=3,\n",
" azimuth=150,\n",
" elevation=-10,\n",
" ),\n",
")\n",
"\n",
"# Create a helper to operate on the MuJoCo model and data.\n",
"mj_model_helper = jaxsim.mujoco.MujocoModelHelper.build_from_xml(\n",
" mjcf_description=mjcf_string, assets=assets\n",
")\n",
"\n",
"# Create the video recorder.\n",
"recorder = jaxsim.mujoco.MujocoVideoRecorder(\n",
" model=mj_model_helper.model,\n",
" data=mj_model_helper.data,\n",
" fps=int(1 / model.time_step),\n",
" width=320 * 2,\n",
" height=240 * 2,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DpRvvGujZH4o"
},
"source": [
"## Open-loop simulation\n",
"\n",
"Now, let's run a simulation to demonstrate the open-loop dynamics of the system."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gSWzcsKWZH4p"
},
"outputs": [],
"source": [
"import mediapy as media\n",
"\n",
"\n",
"# Create a random joint position.\n",
"# For a random full state, you can use jaxsim.api.data.random_model_data.\n",
"random_joint_positions = jax.random.uniform(\n",
" minval=-1.0,\n",
" maxval=1.0,\n",
" shape=(model.dofs(),),\n",
" key=jax.random.PRNGKey(0),\n",
")\n",
"\n",
"# Reset the state to the random joint positions.\n",
"data = js.data.JaxSimModelData.build(model=model, joint_positions=random_joint_positions)\n",
"\n",
"for _ in T:\n",
"\n",
" # Step the JaxSim simulation.\n",
" data = js.model.step(\n",
" model=model,\n",
" data=data,\n",
" joint_force_references=None,\n",
" link_forces=None,\n",
" )\n",
"\n",
" # Update the MuJoCo data.\n",
" mj_model_helper.set_joint_positions(\n",
" positions=data.joint_positions, joint_names=model.joint_names()\n",
" )\n",
"\n",
" # Record a new video frame.\n",
" recorder.record_frame(camera_name=\"cartpole_camera\")\n",
"\n",
"\n",
"# Play the video.\n",
"media.show_video(recorder.frames, fps=recorder.fps)\n",
"recorder.frames = []"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j1rguK3UZH4p"
},
"source": [
"## Closed-loop simulation\n",
"\n",
"Next, let's design a simple computed torque controller. The equations of motion for the cart-pole system are given by:\n",
"\n",
"$$\n",
"M_{ss}(\\mathbf{s}) \\, \\ddot{\\mathbf{s}} + \\mathbf{h}_s(\\mathbf{s}, \\dot{\\mathbf{s}}) = \\boldsymbol{\\tau}\n",
",\n",
"$$\n",
"\n",
"where:\n",
"\n",
"- $\\mathbf{s} \\in \\mathbb{R}^n$ are the joint positions.\n",
"- $\\dot{\\mathbf{s}} \\in \\mathbb{R}^n$ are the joint velocities.\n",
"- $\\ddot{\\mathbf{s}} \\in \\mathbb{R}^n$ are the joint accelerations.\n",
"- $\\boldsymbol{\\tau} \\in \\mathbb{R}^n$ are the joint torques.\n",
"- $M_{ss} \\in \\mathbb{R}^{n \\times n}$ is the mass matrix.\n",
"- $\\mathbf{h}_s \\in \\mathbb{R}^n$ is the vector of bias forces.\n",
"\n",
"JaxSim computes these quantities for floating-base systems, so we specifically focus on the joint-related portions by marking them with subscripts.\n",
"\n",
"Since no external forces or joint friction are present, we can extend a PD controller with a feed-forward term that includes gravity compensation:\n",
"\n",
"$$\n",
"\\begin{cases}\n",
"\\boldsymbol{\\tau} &= M_{ss} \\, \\ddot{\\mathbf{s}}^* + \\mathbf{h}_s \\\\\n",
"\\ddot{\\mathbf{s}}^* &= \\ddot{\\mathbf{s}}^\\text{des} - k_p(\\mathbf{s} - \\mathbf{s}^{\\text{des}}) - k_d(\\dot{\\mathbf{s}} - \\dot{\\mathbf{s}}^{\\text{des}})\n",
"\\end{cases}\n",
"\\quad\n",
",\n",
"$$\n",
"\n",
"where $\\tilde{\\mathbf{s}} = \\left(\\mathbf{s} - \\mathbf{s}^\\text{des}\\right)$ is the joint position error.\n",
"\n",
"With this control law, the closed-loop system dynamics simplifies to:\n",
"\n",
"$$\n",
"\\ddot{\\tilde{\\mathbf{s}}} = -k_p \\tilde{\\mathbf{s}} - k_d \\dot{\\tilde{\\mathbf{s}}}\n",
",\n",
"$$\n",
"\n",
"which, for positive gains $k_p > 0$ and $k_d > 0$, converges asymptotically to zero, ensuring stability."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rfMTCMyGZH4q"
},
"outputs": [],
"source": [
"# @title Create the computed torque controller\n",
"\n",
"# Define the PD gains\n",
"kp = 10.0\n",
"kd = 6.0\n",
"\n",
"\n",
"def computed_torque_controller(\n",
" data: js.data.JaxSimModelData,\n",
" s_des: jax.Array,\n",
" s_dot_des: jax.Array,\n",
") -> jax.Array:\n",
"\n",
" # Compute the gravity compensation term.\n",
" hs = js.model.free_floating_bias_forces(model=model, data=data)[6:]\n",
"\n",
" # Compute the joint-related portion of the floating-base mass matrix.\n",
" Mss = js.model.free_floating_mass_matrix(model=model, data=data)[6:, 6:]\n",
"\n",
" # Get the current joint positions and velocities.\n",
" s = data.joint_positions\n",
" ṡ = data.joint_velocities\n",
"\n",
" # Compute the actuated joint torques.\n",
" s_star = -kp * (s - s_des) - kd * (ṡ - s_dot_des)\n",
" τ = Mss @ s_star + hs\n",
"\n",
" return τ"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ERAUisywZH4q"
},
"source": [
"Now, we can use the `computed_torque_controller` function to compute the torques to apply to the cartpole. Our aim is to stabilize the cartpole in the upright position, so we set both the desired joint positions `s_des` and the desired joint velocities `s_dot_des` to 0."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8YmDdGDVZH4q"
},
"outputs": [],
"source": [
"# @title Run the simulation\n",
"\n",
"# Initialize the data.\n",
"\n",
"# Set the joint positions.\n",
"data = js.data.JaxSimModelData.build(model=model, joint_positions=jnp.array([-0.25, jnp.deg2rad(160)]), joint_velocities=jnp.array([3.00, jnp.deg2rad(10) / model.time_step]))\n",
"\n",
"for _ in T:\n",
"\n",
" # Get the actuated torques from the computed torque controller.\n",
" τ = computed_torque_controller(\n",
" data=data,\n",
" s_des=jnp.array([0.0, 0.0]),\n",
" s_dot_des=jnp.array([0.0, 0.0]),\n",
" )\n",
"\n",
" # Step the JaxSim simulation.\n",
" data = js.model.step(\n",
" model=model,\n",
" data=data,\n",
" joint_force_references=τ,\n",
" )\n",
"\n",
" # Update the MuJoCo data.\n",
" mj_model_helper.set_joint_positions(\n",
" positions=data.joint_positions, joint_names=model.joint_names()\n",
" )\n",
"\n",
" # Record a new video frame.\n",
" recorder.record_frame(camera_name=\"cartpole_camera\")\n",
"\n",
"media.show_video(recorder.frames, fps=recorder.fps)\n",
"recorder.frames = []"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sZ76QqeWeMQz"
},
"source": [
"## Conclusions\n",
"\n",
"In this notebook, we explored how to use JaxSim for developing a closed-loop controller for a robot model. Key takeaways include:\n",
"\n",
"- We performed an open-loop simulation to understand the dynamics of the system without control.\n",
"- We implemented a computed torque controller with PD feedback and a feed-forward gravity compensation term, enabling the stabilization of the system by controlling joint torques.\n",
"- The closed-loop simulation can leverage hardware acceleration on GPUs and TPUs, with the ability to use `jax.vmap` for parallel sampling through automatic vectorization.\n",
"\n",
"JaxSim's closed-loop support can be extended to more advanced, model-based reactive controllers and planners for trajectory optimization. To explore optimization-based methods, consider the following JAX-based projects for hardware-accelerated control and planning:\n",
"\n",
"- [`deepmind/optax`](https://github.com/google-deepmind/optax)\n",
"- [`google/jaxopt`](https://github.com/google/jaxopt)\n",
"- [`patrick-kidger/lineax`](https://github.com/patrick-kidger/lineax)\n",
"- [`patrick-kidger/optimistix`](https://github.com/patrick-kidger/optimistix)\n",
"- [`kevin-tracy/qpax`](https://github.com/kevin-tracy/qpax)\n",
"\n",
"Additionally, if your controllers or planners require the derivatives of the dynamics with respect to the state or inputs, you can obtain them using automatic differentiation directly through JaxSim's API."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuClass": "premium",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "comodo_jaxsim",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: pyproject.toml
================================================
[project]
name = "jaxsim"
dynamic = ["version"]
requires-python = ">= 3.10"
description = "A differentiable physics engine and multibody dynamics library for control and robot learning."
authors = [
{ name = "Diego Ferigo", email = "dgferigo@gmail.com" },
{ name = "Filippo Luca Ferretti", email = "filippoluca.ferretti@outlook.com" },
]
maintainers = [
{ name = "Filippo Luca Ferretti", email = "filippo.ferretti@outlook.com" },
]
license = "BSD-3-Clause"
license-files = ["LICENSE"]
keywords = [
"physics",
"physics engine",
"jax",
"rigid body dynamics",
"featherstone",
"reinforcement learning",
"robot",
"robotics",
"sdf",
"urdf",
]
classifiers = [
"Development Status :: 4 - Beta",
"Framework :: Robot Framework",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Operating System :: POSIX :: Linux",
"Operating System :: MacOS",
"Operating System :: Microsoft",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Games/Entertainment :: Simulation",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Physics",
"Topic :: Software Development",
]
dependencies = [
"coloredlogs",
"jax >= 0.4.34",
"jaxlib >= 0.4.34",
"jaxlie >= 1.3.0",
"jax_dataclasses >= 1.4.0",
"pptree",
"optax >= 0.2.3",
"qpax",
"rod >= 0.4.1",
"typing_extensions ; python_version < '3.12'",
"trimesh",
]
[project.optional-dependencies]
testing = [
"chex >= 0.1.91",
"idyntree >= 12.2.1",
"pytest >=6.0",
"pytest-benchmark",
"pytest-icdiff",
"pytest-xdist",
"robot-descriptions >= 1.16.0",
"icub-models",
]
viz = [
"lxml",
"mediapy",
"mujoco >= 3.0.0",
"scipy >= 1.14.0",
]
all = [
"jaxsim[testing,viz]",
]
[project.readme]
file = "README.md"
content-type = "text/markdown"
[project.urls]
Changelog = "https://github.com/gbionics/jaxsim/releases"
Documentation = "https://jaxsim.readthedocs.io"
Source = "https://github.com/gbionics/jaxsim"
Tracker = "https://github.com/gbionics/jaxsim/issues"
# ===========
# Build tools
# ===========
[build-system]
build-backend = "hatchling.build"
requires = [
"hatchling",
"hatch-vcs",
]
[tool.hatch.version]
source = "vcs"
raw-options = { local_scheme = "dirty-tag" }
[tool.hatch.build.targets.wheel]
packages = ["src/jaxsim"]
[tool.hatch.build.hooks.vcs]
version-file = "src/jaxsim/_version.py"
# =================
# Style and testing
# =================
[tool.black]
line-length = 88
[tool.isort]
multi_line_output = 3
profile = "black"
[tool.pytest.ini_options]
addopts = "-rsxX -v --strict-markers --benchmark-skip --benchmark-warmup=ON"
minversion = "6.0"
testpaths = [
"tests",
]
# ==================
# Ruff configuration
# ==================
[tool.ruff]
exclude = [
".git",
".pixi",
".pytest_cache",
".ruff_cache",
".idea",
".vscode",
".devcontainer",
"__pycache__",
]
preview = true
[tool.ruff.lint]
# https://docs.astral.sh/ruff/rules/
select = [
"B",
"D",
"E",
"F",
"I",
"W",
"RUF",
"UP",
"YTT",
]
ignore = [
"B008", # Function call in default argument
"B024", # Abstract base class without abstract methods
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic method
"D200", # One-line docstring should fit on one line with quotes
"D202", # No blank lines allowed after function docstring
"D203", # Incorrect blank line before class
"D205", # 1 blank line required between summary line and description
"D212", # Multi-line docstring summary should start at the first line
"D411", # Missing blank line before section
"D413", # Missing blank line after last section
"E402", # Module level import not at top of file
"E501", # Line too long
"E731", # Do not assign a `lambda` expression, use a `def`
"E741", # Ambiguous variable name
"I001", # Import block is unsorted or unformatted
"RUF003", # Ambiguous unicode character in comment
]
[tool.ruff.lint.per-file-ignores]
# Ignore `E402` (module-level import not at top of file) in the tests, docs, and tools folders.
"**/{tests,docs,tools}/*" = ["E402"]
"**/{tests,examples}/*" = ["B007", "D100", "D102", "D103"]
"__init__.py" = ["F401", "RUF067"]
"docs/conf.py" = ["F401"]
"src/jaxsim/exceptions.py" = ["D401"]
"src/jaxsim/logging.py" = ["D101", "D103"]
# ==================
# Pixi configuration
# ==================
[tool.pixi.workspace]
channels = ["conda-forge"]
platforms = ["linux-64", "linux-aarch64", "osx-arm64", "osx-64"]
requires-pixi = ">=0.39.0"
preview = ["pixi-build"]
[tool.pixi.environments]
# We resolve only two groups: cpu and gpu.
# Then, multiple environments can be created from these groups.
default = { features = ["test", "examples"] }
gpu = { features = ["test", "examples", "gpu"] }
# ---------------
# feature.default
# ---------------
# Dependencies from conda-forge.
[tool.pixi.dependencies]
#
# Matching `project.dependencies`.
#
coloredlogs = "*"
jax = "*"
jaxlib = "*"
jaxlie = "*"
jax-dataclasses = "*"
pptree = "*"
optax = "*"
qpax = "*"
rod = ">=0.4.1"
trimesh = "*"
typing_extensions = "*"
#
# Optional dependencies.
#
lxml = "*"
mediapy = "*"
mujoco = "*"
scipy = "*"
#
# Additional dependencies.
#
pip = "*"
hatchling = "*"
hatch-vcs = "*"
# Dependencies from PyPI.
[tool.pixi.pypi-dependencies]
jaxsim = { path = "./", editable = true }
[tool.pixi.pypi-options]
no-build-isolation = ["jaxsim"]
# ------------
# feature.test
# ------------
[tool.pixi.feature.test.tasks]
pipcheck = "pip check"
benchmark = { cmd = "pytest --benchmark-only --benchmark-warmup=ON", depends-on = ["pipcheck"] }
test = { cmd = "pytest", depends-on = ["pipcheck"] }
[tool.pixi.feature.test.dependencies]
black-jupyter = "*"
chex = ">=0.1.91"
idyntree = "*"
isort = "*"
pre-commit = "*"
pytest = "*"
pytest-benchmark = "*"
pytest-icdiff = "*"
pytest-xdist = "*"
robot_descriptions = ">=1.16.0"
# ----------------
# feature.examples
# ----------------
[tool.pixi.feature.examples.tasks]
examples = { cmd = "jupyter notebook ./examples" }
[tool.pixi.feature.examples.dependencies]
notebook = "*"
robot_descriptions = ">=1.16.0"
# -----------
# feature.gpu
# -----------
[tool.pixi.feature.gpu]
platforms = ["linux-64"]
system-requirements = { cuda = "13" }
[tool.pixi.feature.gpu.dependencies]
jaxlib = { version = "*", build = "*cuda*" }
[tool.pixi.feature.gpu.tasks]
test-gpu = { cmd = "pytest --gpu-only", depends-on = ["pipcheck"] }
================================================
FILE: src/jaxsim/__init__.py
================================================
from . import logging
from ._version import __version__
# Follow upstream development in https://github.com/google/jax/pull/13304
def _jnp_options() -> None:
    """
    Configure the floating-point precision used by JAX.

    64-bit precision is enabled unless the ``JAX_ENABLE_X64`` environment
    variable is set to ``"0"``. On TPU and Metal backends, 32-bit precision
    is enforced and a warning is logged.
    """

    import os

    import jax

    # Query the platform of the default device only once.
    platform = jax.devices()[0].platform

    # Check if running on TPU.
    is_tpu = platform == "tpu"

    # Check if running on Metal.
    is_metal = platform == "METAL"

    # Enable by default 64-bit precision to get accurate physics.
    # Users can enforce 32-bit precision by setting the following variable to 0.
    use_x64 = os.environ.get("JAX_ENABLE_X64", "1") != "0"

    # Notify the user if unsupported 64-bit precision was requested on TPU or Metal.
    if (is_tpu or is_metal) and use_x64:
        # `upper()` must be called: formatting the bound method prints its repr.
        msg = f"64-bit precision is not allowed on {platform.upper()}. Enforcing 32bit precision."
        logging.warning(msg)
        use_x64 = False

    if is_metal:
        logging.warning(
            "JAX Metal backend is experimental. Some functionalities may not be available."
        )

    # Enable 64-bit precision in JAX.
    if use_x64:
        logging.info("Enabling JAX to use 64-bit precision")
        jax.config.update("jax_enable_x64", True)

    # Warn about experimental usage of 32-bit precision.
    else:
        logging.warning(
            "Using 32-bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
        )
def _np_options() -> None:
import numpy as np
np.set_printoptions(precision=5, suppress=True, linewidth=150, threshold=10_000)
def _is_editable() -> bool:
import importlib.util
import pathlib
import site
# Get the ModuleSpec of jaxsim.
jaxsim_spec = importlib.util.find_spec(name="jaxsim")
# This can be None. If it's None, assume non-editable installation.
if jaxsim_spec.origin is None:
return False
# Get the folder containing the jaxsim package.
jaxsim_package_dir = str(pathlib.Path(jaxsim_spec.origin).parent.parent)
# The installation is editable if the package dir is not in any {site|dist}-packages.
return jaxsim_package_dir not in site.getsitepackages()
def _get_default_logging_level() -> logging.LoggingLevel:
    """
    Determine the logging level to use when JaxSim is imported.

    The level is selected, in order of priority, from the ``JAXSIM_LOGGING_LEVEL``
    environment variable, from the presence of a trace function (DEBUG), or from
    the installation type (INFO if editable, WARNING otherwise).

    Returns:
        The logging level to set.

    Raises:
        RuntimeError: If ``JAXSIM_LOGGING_LEVEL`` does not name a valid level.
    """

    import os
    import sys

    # An explicit environment variable always takes precedence.
    requested_level = os.environ.get("JAXSIM_LOGGING_LEVEL")

    if requested_level:
        try:
            return logging.LoggingLevel[requested_level.upper()]
        except KeyError as exc:
            msg = "Invalid logging level defined in JAXSIM_LOGGING_LEVEL"
            raise RuntimeError(msg) from exc

    # A trace function is installed when running e.g. under a debugger.
    get_trace = getattr(sys, "gettrace", lambda: None)

    if get_trace():
        return logging.LoggingLevel.DEBUG

    # Editable installations get the more verbose INFO level, while regular
    # installations get WARNING to avoid too verbose logging.
    if _is_editable():  # noqa: F821
        return logging.LoggingLevel.INFO

    return logging.LoggingLevel.WARNING
# Configure the logger with the default logging level.
logging.configure(level=_get_default_logging_level())

# Configure the floating-point precision used by JAX.
_jnp_options()

# Initialize the numpy print options.
_np_options()

# The setup helpers are only needed at import time: delete them so that they
# do not remain accessible as attributes of the package.
del _jnp_options
del _np_options
del _get_default_logging_level
del _is_editable
from . import terrain # isort:skip
from . import api, logging, math, rbda
from .api.common import VelRepr
================================================
FILE: src/jaxsim/api/__init__.py
================================================
from . import common # isort:skip
from . import model, data # isort:skip
from . import (
actuation_model,
com,
contact,
frame,
integrators,
joint,
kin_dyn_parameters,
link,
ode,
references,
)
================================================
FILE: src/jaxsim/api/actuation_model.py
================================================
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.typing as jtp
def compute_resultant_torques(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_force_references: jtp.Vector | None = None,
) -> jtp.Vector:
    """
    Compute the resultant torques acting on the joints.

    The resultant torques are the sum of:
        - the joint force references;
        - the joint position-limit torques (spring-damper);
        - the static and viscous joint friction torques, if enabled in the
          actuation parameters.

    The sum is saturated with the torque limits computed by `tn_curve_fn`.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_force_references:
            The joint force references to apply. If None, zero references are used.

    Returns:
        The resultant torques acting on the joints.
    """

    # Build joint torques if not provided.
    τ_references = (
        jnp.atleast_1d(joint_force_references.squeeze())
        if joint_force_references is not None
        else jnp.zeros_like(data.joint_positions)
    ).astype(float)

    # ====================
    # Enforce joint limits
    # ====================

    τ_position_limit = jnp.zeros_like(τ_references).astype(float)

    if model.dofs() > 0:

        # Stiffness and damper parameters for the joint position limits.
        k_j = jnp.array(
            model.kin_dyn_parameters.joint_parameters.position_limit_spring
        ).astype(float)
        d_j = jnp.array(
            model.kin_dyn_parameters.joint_parameters.position_limit_damper
        ).astype(float)

        # Compute the joint position limit violations: non-positive below the lower
        # limit, non-negative above the upper limit, zero within the limits.
        lower_violation = jnp.clip(
            data.joint_positions
            - model.kin_dyn_parameters.joint_parameters.position_limits_min,
            max=0.0,
        )

        upper_violation = jnp.clip(
            data.joint_positions
            - model.kin_dyn_parameters.joint_parameters.position_limits_max,
            min=0.0,
        )

        violation = lower_violation + upper_violation

        # Spring torque pushing the violating joints back within their limits.
        τ_position_limit -= k_j * violation

        # Damper torque, active only on the joints that violate their limits.
        # An explicit mask keeps the term in units of torque and always opposing
        # the joint velocity. Scaling it by the spring torque (as the identity
        # `jnp.positive` did) makes it anti-damping on upper-limit violations.
        τ_position_limit -= jnp.where(
            violation != 0.0, d_j * data.joint_velocities, 0.0
        )

    # ====================
    # Joint friction model
    # ====================

    τ_friction = jnp.zeros_like(τ_references).astype(float)

    # Apply joint friction only if enabled in the actuation parameters.
    if model.dofs() > 0 and model.actuation_params.enable_friction:

        # Static and viscous joint friction parameters
        kc = jnp.array(
            model.kin_dyn_parameters.joint_parameters.friction_static
        ).astype(float)
        kv = jnp.array(
            model.kin_dyn_parameters.joint_parameters.friction_viscous
        ).astype(float)

        # Compute the joint friction torque (static + viscous).
        τ_friction = -(
            kc * jnp.sign(data.joint_velocities) + kv * data.joint_velocities
        )

    # ===============================
    # Compute the total joint forces.
    # ===============================

    τ_total = τ_references + τ_friction + τ_position_limit

    # Saturate the total torques with the velocity-dependent torque limits.
    τ_lim = tn_curve_fn(model=model, data=data)
    τ_total = jnp.clip(τ_total, -τ_lim, τ_lim)

    return τ_total
def tn_curve_fn(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    """
    Compute the joint torque limits from the torque-speed (T-N) curve.

    The limit is ``torque_max`` up to the threshold speed ``omega_th``. It
    decreases linearly between ``omega_th`` and ``omega_max``, and is zero
    beyond ``omega_max``.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The torque limits.
    """

    params = model.actuation_params
    τ_max = params.torque_max  # Max torque (Nm)
    ω_th = params.omega_th  # Threshold speed (rad/s)
    ω_max = params.omega_max  # Max speed for torque drop-off (rad/s)

    speed = jnp.abs(data.joint_velocities)

    # Linear drop-off between the threshold speed and the maximum speed,
    # and zero torque above the maximum speed.
    τ_ramp = τ_max * (1 - (speed - ω_th) / (ω_max - ω_th))
    τ_above_threshold = jnp.where(speed <= ω_max, τ_ramp, 0.0)

    return jnp.where(speed <= ω_th, τ_max, τ_above_threshold)
================================================
FILE: src/jaxsim/api/com.py
================================================
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.math
import jaxsim.typing as jtp
from .common import VelRepr
@jax.jit
@js.common.named_scope
def com_position(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    """
    Compute the position of the center of mass of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The position of the center of mass of the model w.r.t. the world frame.
    """

    W_H_B = data._base_transform
    B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B)
    W_H_L = data._link_transforms

    def mass_weighted_link_com(link_index) -> jtp.Vector:
        # Homogeneous CoM position of the link, expressed in the base frame
        # and scaled by the link mass.
        link_mass = js.link.mass(model=model, link_index=link_index)
        L_p_LCoM = js.link.com_position(
            model=model, data=data, link_index=link_index, in_link_frame=True
        )
        return link_mass * B_H_W @ W_H_L[link_index] @ jnp.hstack([L_p_LCoM, 1])

    link_indices = jnp.arange(model.number_of_links())
    weighted_coms = jax.vmap(mass_weighted_link_com)(link_indices)

    total_mass = js.model.total_mass(model=model)

    # Mass-weighted average in the base frame, then restore the homogeneous
    # coordinate before mapping the point back to the world frame.
    B_p̃_CoM = (1 / total_mass) * weighted_coms.sum(axis=0)
    B_p̃_CoM = B_p̃_CoM.at[3].set(1)

    W_p̃_CoM = W_H_B @ B_p̃_CoM
    return W_p̃_CoM[0:3].astype(float)
@jax.jit
@js.common.named_scope
def com_linear_velocity(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the linear velocity of the center of mass of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The linear velocity of the center of mass of the model in the
        active representation.

    Note:
        The linear velocity of the center of mass is expressed in the mixed frame
        :math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
        active velocity representation is either inertial-fixed or mixed,
        and :math:`[C] = [B]` if the active velocity representation is body-fixed.
    """

    # The 6D average centroidal velocity is already expressed in G[B]
    # (body-fixed) or G[W] (inertial-fixed or mixed): keep only its linear part.
    G_v_WG = average_centroidal_velocity(model=model, data=data)

    return G_v_WG[0:3]
@jax.jit
@js.common.named_scope
def centroidal_momentum(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the centroidal momentum of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The centroidal momentum of the model.

    Note:
        The centroidal momentum is expressed in the mixed frame
        :math:`({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`C = W` if the
        active velocity representation is either inertial-fixed or mixed,
        and :math:`C = B` if the active velocity representation is body-fixed.
    """

    # The centroidal momentum is the centroidal momentum matrix applied to the
    # generalized velocity.
    return (
        centroidal_momentum_jacobian(model=model, data=data)
        @ data.generalized_velocity
    )
@jax.jit
@js.common.named_scope
def centroidal_momentum_jacobian(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    r"""
    Compute the Jacobian of the centroidal momentum of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The Jacobian of the centroidal momentum of the model.

    Note:
        The frame corresponding to the output representation of this Jacobian is either
        :math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,
        or :math:`G[B]`, if the active velocity representation is body-fixed.

    Note:
        This Jacobian is also known in the literature as Centroidal Momentum Matrix.
    """

    # Start from the total momentum Jacobian with body-fixed output representation;
    # its output is converted below either to G[W] or to G[B].
    B_Jh = js.model.total_momentum_jacobian(
        model=model, data=data, output_vel_repr=VelRepr.Body
    )

    W_H_B = data._base_transform
    W_p_CoM = com_position(model=model, data=data)

    # The frame G has its origin at the CoM, and the orientation of W or B.
    if data.velocity_representation in (VelRepr.Inertial, VelRepr.Mixed):
        W_H_G = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
    elif data.velocity_representation == VelRepr.Body:
        W_H_G = W_H_B.at[0:3, 3].set(W_p_CoM)
    else:
        raise ValueError(data.velocity_representation)

    B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G

    # Transform for 6D forces from B to G.
    G_Xf_B = jaxsim.math.Adjoint.from_transform(transform=B_H_G).T

    return G_Xf_B @ B_Jh
@jax.jit
@js.common.named_scope
def locked_centroidal_spatial_inertia(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
):
"""
Compute the locked centroidal spatial inertia of the model.
Args:
model: The model to consider.
data: The data of the considered model.
Returns:
The locked centroidal spatial inertia of the model.
"""
with data.switch_velocity_representation(VelRepr.Body):
B_Mbb_B = js.model.locked_spatial_inertia(model=model, data=data)
W_H_B = data._base_transform
W_p_CoM = com_position(model=model, data=data)
match data.velocity_representation:
case VelRepr.Inertial | VelRepr.Mixed:
W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841
case VelRepr.Body:
W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841
case _:
raise ValueError(data.velocity_representation)
B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G
B_Xv_G = jaxsim.math.Adjoint.from_transform(transform=B_H_G)
G_Xf_B = B_Xv_G.transpose()
return G_Xf_B @ B_Mbb_B @ B_Xv_G
@jax.jit
@js.common.named_scope
def average_centroidal_velocity(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
    r"""
    Compute the average centroidal velocity of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The average centroidal velocity of the model.

    Note:
        The average velocity is expressed in the mixed frame
        :math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
        active velocity representation is either inertial-fixed or mixed,
        and :math:`[C] = [B]` if the active velocity representation is body-fixed.
    """

    # Map the generalized velocity through the average centroidal velocity Jacobian.
    return (
        average_centroidal_velocity_jacobian(model=model, data=data)
        @ data.generalized_velocity
    )
@jax.jit
@js.common.named_scope
def average_centroidal_velocity_jacobian(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    r"""
    Compute the Jacobian of the average centroidal velocity of the model.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The Jacobian of the average centroidal velocity of the model.

    Note:
        The frame corresponding to the output representation of this Jacobian is either
        :math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,
        or :math:`G[B]`, if the active velocity representation is body-fixed.
    """

    G_J = centroidal_momentum_jacobian(model=model, data=data)
    G_Mbb = locked_centroidal_spatial_inertia(model=model, data=data)

    # Solve G_Mbb @ X = G_J instead of forming the explicit inverse of G_Mbb:
    # it is cheaper and numerically more accurate.
    return jnp.linalg.solve(G_Mbb, G_J)
@jax.jit
@js.common.named_scope
def bias_acceleration(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
r"""
Compute the bias linear acceleration of the center of mass.
Args:
model: The model to consider.
data: The data of the considered model.
Returns:
The bias linear acceleration of the center of mass in the active representation.
Note:
The bias acceleration is expressed in the mixed frame
:math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
active velocity representation is either inertial-fixed or mixed,
and :math:`[C] = [B]` if the active velocity representation is body-fixed.
"""
# Compute the pose of all links with forward kinematics.
W_H_L = data._link_transforms
# Compute the bias acceleration of all links by zeroing the generalized velocity
# in the active representation.
v̇_bias_WL = js.model.link_bias_accelerations(model=model, data=data)
def other_representation_to_body(
C_v̇_WL: jtp.Vector, C_v_WC: jtp.Vector, L_H_C: jtp.Matrix, L_v_LC: jtp.Vector
) -> jtp.Vector:
"""
Convert the body-fixed representation of the link bias acceleration
C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL.
"""
L_X_C = jaxsim.math.Adjoint.from_transform(transform=L_H_C)
C_X_L = jaxsim.math.Adjoint.inverse(L_X_C)
L_v̇_WL = L_X_C @ (C_v̇_WL + jaxsim.math.Cross.vx(C_X_L @ L_v_LC) @ C_v_WC)
return L_v̇_WL
# We need here to get the body-fixed bias acceleration of the links.
# Since it's computed in the active representation, we need to convert it to body.
match data.velocity_representation:
case VelRepr.Body:
L_a_bias_WL = v̇_bias_WL
case VelRepr.Inertial:
C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841
C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841
L_H_C = L_H_W = jax.vmap(jaxsim.math.Transform.inverse)(W_H_L) # noqa: F841
L_v_LC = L_v_LW = jax.vmap( # noqa: F841
lambda i: -js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
)
)(jnp.arange(model.number_of_links()))
L_a_bias_WL = jax.vmap(
lambda i: other_representation_to_body(
C_v̇_WL=C_v̇_WL[i],
C_v_WC=C_v_WC,
L_H_C=L_H_C[i],
L_v_LC=L_v_LC[i],
)
)(jnp.arange(model.number_of_links()))
case VelRepr.Mixed:
C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL # noqa: F841
C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841
lambda i: js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed
)
.at[3:6]
.set(jnp.zeros(3))
)(jnp.arange(model.number_of_links()))
L_H_C = L_H_LW = jax.vmap( # noqa: F841
lambda W_H_L: jaxsim.math.Transform.inverse(
W_H_L.at[0:3, 3].set(jnp.zeros(3))
)
)(W_H_L)
L_v_LC = L_v_L_LW = jax.vmap( # noqa: F841
lambda i: -js.link.velocity(
model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body
)
.at[0:3]
.set(jnp.zeros(3))
)(jnp.arange(model.number_of_links()))
L_a_bias_WL = jax.vmap(
lambda i: other_representation_to_body(
C_v̇_WL=C_v̇_WL[i],
C_v_WC=C_v_WC[i],
L_H_C=L_H_C[i],
L_v_LC=L_v_LC[i],
)
)(jnp.arange(model.number_of_links()))
case _:
raise ValueError(data.velocity_representation)
# Compute the bias of the 6D momentum derivative.
def bias_momentum_derivative_term(
link_index: jtp.Int, L_a_bias_WL: jtp.Vector
) -> jtp.Vector:
# Get the body-fixed 6D inertia matrix.
L_M_L = js.link.spatial_inertia(model=model, link_index=link_index)
# Compute the body-fixed 6D velocity.
L_v_WL = js.link.velocity(
model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body
)
# Compute the world-to-link transformations for 6D forces.
W_Xf_L = jaxsim.math.Adjoint.from_transform(
transform=W_H_L[link_index], inverse=True
).T
# Compute the contribution of the link to the bias acceleration of the CoM.
W_ḣ_bias_link_contribution = W_Xf_L @ (
L_M_L @ L_a_bias_WL + jaxsim.math.Cross.vx_star(L_v_WL) @ L_M_L @ L_v_WL
)
return W_ḣ_bias_link_contribution
# Sum the contributions of all links to the bias acceleration of the CoM.
W_ḣ_bias = jax.vmap(bias_momentum_derivative_term)(
jnp.arange(model.number_of_links()), L_a_bias_WL
).sum(axis=0)
# Compute the total mass of the model.
m = js.model.total_mass(model=model)
# Compute the position of the CoM.
W_p_CoM = com_position(model=model, data=data)
match data.velocity_representation:
# G := G[W] = (W_p_CoM, [W])
case VelRepr.Inertial | VelRepr.Mixed:
W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
GW_Xf_W = jaxsim.math.Adjoint.from_transform(W_H_GW).T
GW_ḣ_bias = GW_Xf_W @ W_ḣ_bias
GW_v̇l_com_bias = GW_ḣ_bias[0:3] / m
return GW_v̇l_com_bias
# G := G[B] = (W_p_CoM, [B])
case VelRepr.Body:
GB_Xf_W = jaxsim.math.Adjoint.from_transform(
transform=data._base_transform.at[0:3].set(W_p_CoM)
).T
GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias
GB_v̇l_com_bias = GB_ḣ_bias[0:3] / m
return GB_v̇l_com_bias
case _:
raise ValueError(data.velocity_representation)
================================================
FILE: src/jaxsim/api/common.py
================================================
import abc
import contextlib
import dataclasses
import enum
import functools
from collections.abc import Callable, Iterator
from typing import ParamSpec, TypeVar
import jax
import jax.numpy as jnp
import jax_dataclasses
from jax_dataclasses import Static
import jaxsim.typing as jtp
from jaxsim.math import Adjoint
from jaxsim.utils import JaxsimDataclass, Mutability
try:
from typing import Self
except ImportError:
from typing_extensions import Self
_P = ParamSpec("_P")
_R = TypeVar("_R")
def named_scope(fn, name: str | None = None) -> Callable[_P, _R]:
    """
    Wrap a function so that its body runs within a JAX named scope.

    Args:
        fn: The function to wrap.
        name: The name of the scope. If empty or None, the name of ``fn`` is used.

    Returns:
        The wrapped function, carrying the metadata of ``fn``.
    """

    def scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        scope_name = name if name else fn.__name__
        with jax.named_scope(scope_name):
            return fn(*args, **kwargs)

    return functools.wraps(fn)(scoped)
@enum.unique
class VelRepr(enum.IntEnum):
    """
    Enumeration of all supported 6D velocity representations.

    Members:
        Body: Body-fixed representation.
        Mixed: Mixed representation.
        Inertial: Inertial-fixed representation.
    """

    Body = 1
    Mixed = 2
    Inertial = 3
@jax_dataclasses.pytree_dataclass
class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
"""
Base class for model data structures with velocity representation.
"""
velocity_representation: Static[VelRepr] = dataclasses.field(
default=VelRepr.Inertial, kw_only=True
)
@contextlib.contextmanager
def switch_velocity_representation(
    self, velocity_representation: VelRepr
) -> Iterator[Self]:
    """
    Context manager to temporarily switch the velocity representation.

    The object is modified in place. The original velocity representation is
    restored when the context exits, also when an exception is raised
    inside the context.

    Args:
        velocity_representation: The new velocity representation.

    Yields:
        The same object with the new velocity representation.
    """
    original_representation = self.velocity_representation
    try:
        # First, we replace the velocity representation.
        # The field is static, so it is written without validation.
        with self.mutable_context(
            mutability=Mutability.MUTABLE_NO_VALIDATION,
            restore_after_exception=True,
        ):
            self.velocity_representation = velocity_representation
        # Then, we yield the data with changed representation.
        # We run this in a mutable context with restoration so that, if any
        # exception occurs, the original object is restored in case it was modified.
        # The mutability is kept equal to the current one of the object.
        with self.mutable_context(
            mutability=self.mutability(), restore_after_exception=True
        ):
            yield self
    finally:
        # Always restore the original velocity representation on exit.
        with self.mutable_context(
            mutability=Mutability.MUTABLE_NO_VALIDATION,
            restore_after_exception=True,
        ):
            self.velocity_representation = original_representation
@staticmethod
@functools.partial(jax.jit, static_argnames=["other_representation", "is_force"])
def inertial_to_other_representation(
    array: jtp.Array,
    other_representation: VelRepr,
    transform: jtp.Matrix,
    *,
    is_force: bool,
) -> jtp.Array:
    r"""
    Convert a 6D quantity from inertial-fixed to another representation.

    Args:
        array: The 6D quantity to convert.
        other_representation: The representation to convert to.
        transform:
            The :math:`W \mathbf{H}_O` transform, where :math:`O` is the
            reference frame of the other representation.
        is_force: Whether the quantity is a 6D force or a 6D velocity.

    Returns:
        The 6D quantity in the other representation.
    """

    # Nothing to convert: the quantity is already inertial-fixed.
    if other_representation == VelRepr.Inertial:
        return array

    # Select the target frame:
    # - Body: the frame O itself.
    # - Mixed: the frame O[W], with the origin of O and the orientation of W.
    if other_representation == VelRepr.Body:
        W_H_target = transform
    elif other_representation == VelRepr.Mixed:
        W_H_target = transform.at[..., 0:3, 0:3].set(jnp.eye(3))
    else:
        raise ValueError(other_representation)

    if is_force:
        # 6D forces are mapped with the transpose of the velocity adjoint.
        target_X_W = Adjoint.from_transform(transform=W_H_target).swapaxes(-1, -2)
    else:
        target_X_W = Adjoint.from_transform(transform=W_H_target, inverse=True)

    # Matrix-vector product, broadcast over any leading batch dimensions.
    return jnp.einsum("...ij,...j->...i", target_X_W, array)
@staticmethod
@functools.partial(jax.jit, static_argnames=["other_representation", "is_force"])
def other_representation_to_inertial(
array: jtp.Array,
other_representation: VelRepr,
transform: jtp.Matrix,
*,
is_force: bool,
) -> jtp.Array:
r"""
Convert a 6D quantity from another representation to inertial-fixed.
Args:
array: The 6D quantity to convert.
other_representation: The representation to convert from.
transform:
The `math:W \mathbf{H}_O` transform, where `math:O` is the
reference frame of the other representation.
is_force: Whether the quantity is a 6D force or a 6D velocity.
Returns:
The 6D quantity in the inertial-fixed representation.
"""
O_array = array
W_H_O = transform
match other_representation:
case VelRepr.Inertial:
return O_array
case VelRepr.Body:
if not is_force:
W_Xv_O = Adjoint.from_transform(W_H_O)
W_array = jnp.einsum("...ij,...j->...i", W_Xv_O, O_array)
else:
W_Xf_O = Adjoint.from_transform(
transform=W_H_O, inverse=True
).swapaxes(-1, -2)
W_array = jnp.einsum("...ij,...j->...i", W_Xf_O, O_array)
return W_array
case VelRepr.Mixed:
W_H_OW = W_H_O.at[..., 0:3, 0:3].set(jnp.eye(3))
if not is_force:
W_Xv_BW = Adjoint.from_transform(W_H_OW)
W_array = jnp.einsum("...ij,...j->...i", W_Xv_BW, O_array)
else:
W_Xf_BW = Adjoint.from_transform(
transform=W_H_OW, inverse=True
).swapaxes(-1, -2)
W_array = jnp.einsum("...ij,...j->...i", W_Xf_BW, O_array)
return W_array
case _:
raise ValueError(other_representation)
================================================
FILE: src/jaxsim/api/contact.py
================================================
from __future__ import annotations
import functools
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.exceptions
import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.math import Adjoint, Cross, Transform
from jaxsim.rbda.contacts import SoftContacts
from .common import VelRepr
@jax.jit
@js.common.named_scope
def collidable_point_kinematics(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> tuple[jtp.Matrix, jtp.Matrix]:
    """
    Compute the world-frame position and 3D velocity of the collidable points.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        A tuple with the positions and the velocities of the collidable points,
        both expressed in the world frame.

    Note:
        The returned velocity is the plain time derivative of the position.
        Attaching a frame C = (p_C, [C]) to a collidable point, it equals the
        linear component of the mixed 6D velocity of that frame.
    """

    # Both quantities are obtained from the link transforms and link velocities
    # cached in the data object.
    pos_vel = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
        model=model,
        link_transforms=data._link_transforms,
        link_velocities=data._link_velocities,
    )

    positions, velocities = pos_vel
    return positions, velocities
@jax.jit
@js.common.named_scope
def collidable_point_positions(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    """
    Get the world-frame position of the collidable points.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The positions of the collidable points, expressed in the world frame.
    """

    # The kinematics helper returns (positions, velocities): keep the former.
    return collidable_point_kinematics(model=model, data=data)[0]
@jax.jit
@js.common.named_scope
def collidable_point_velocities(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    """
    Get the world-frame 3D velocity of the collidable points.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The 3D velocities of the collidable points.
    """

    # The kinematics helper returns (positions, velocities): keep the latter.
    return collidable_point_kinematics(model=model, data=data)[1]
@functools.partial(jax.jit, static_argnames=["link_names"])
@js.common.named_scope
def in_contact(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_names: tuple[str, ...] | None = None,
) -> jtp.Vector:
    """
    Check which links are in contact with the terrain.

    A link is considered in contact when at least one of its enabled collidable
    points lies at or below the terrain height.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_names:
            The names of the links to check. If None, all links are checked.

    Returns:
        A boolean vector with one entry per considered link, True when the link
        is in contact with the terrain.

    Raises:
        ValueError: If any of the given link names does not belong to the model.
    """

    if link_names is not None and set(link_names).difference(model.link_names()):
        raise ValueError("One or more link names are not part of the model")

    contact_parameters = model.kin_dyn_parameters.contact_parameters

    # Parent link of each enabled collidable point.
    enabled_points = contact_parameters.indices_of_enabled_collidable_points
    parent_of_point = jnp.array(contact_parameters.body, dtype=int)[enabled_points]

    W_p_Ci = collidable_point_positions(model=model, data=data)

    # Compare the height of each point against the terrain right below it.
    terrain_z = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))(
        W_p_Ci[:, 0], W_p_Ci[:, 1]
    )
    point_below_terrain = W_p_Ci[:, 2] <= terrain_z

    if link_names is None:
        link_idxs = jnp.arange(model.number_of_links())
    else:
        link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)

    def link_touches_terrain(link_index):
        # True if any point belonging to this link is below the terrain.
        return jnp.logical_and(parent_of_point == link_index, point_below_terrain).any()

    return jax.vmap(link_touches_terrain)(link_idxs)
def estimate_good_soft_contacts_parameters(
    *args, **kwargs
) -> jaxsim.rbda.contacts.ContactParamsTypes:
    """
    Estimate good soft contacts parameters.

    Deprecated: this alias logs a warning and forwards all of its arguments to
    `estimate_good_contact_parameters`, which should be used instead.
    """

    replacement = estimate_good_contact_parameters
    logging.warning(f"This method is deprecated, please use `{replacement.__name__}`.")
    return replacement(*args, **kwargs)
def estimate_good_contact_parameters(
    model: js.model.JaxSimModel,
    *,
    standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
    static_friction_coefficient: jtp.FloatLike = 0.5,
    number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
    damping_ratio: jtp.FloatLike = 1.0,
    max_penetration: jtp.FloatLike | None = None,
) -> jaxsim.rbda.contacts.ContactParamsTypes:
    """
    Estimate good contact parameters.

    Args:
        model: The model to consider.
        standard_gravity: The standard gravity acceleration.
        static_friction_coefficient: The static friction coefficient.
        number_of_active_collidable_points_steady_state:
            The number of active collidable points in steady state.
        damping_ratio: The damping ratio.
        max_penetration:
            The maximum penetration allowed. If None, it defaults to 1% of the
            height of the center of mass in the zero configuration. For
            floating-base models, this height is measured from the lowest
            collidable point.

    Returns:
        The estimated good contacts parameters.

    Note:
        This is primarily a convenience function for soft-like contact models.
        However, it also provides good default parameters for the other ones.

    Note:
        This method provides a good set of contacts parameters.
        The user is encouraged to fine-tune the parameters based on the
        specific application.
    """

    if max_penetration is None:
        zero_data = js.data.JaxSimModelData.build(model=model)
        W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]

        if model.floating_base():
            W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]

            # A model without enabled collidable points has no lowest point to
            # reference the CoM height against, and `min()` of an empty array raises.
            if W_pz_C.size > 0:
                W_pz_CoM = W_pz_CoM - W_pz_C.min()

        # By default, allow a penetration of 1% of the CoM height.
        max_penetration = 0.01 * W_pz_CoM

    nc = number_of_active_collidable_points_steady_state
    return model.contact_model._parameters_class().build_default_from_jaxsim_model(
        model=model,
        standard_gravity=standard_gravity,
        static_friction_coefficient=static_friction_coefficient,
        max_penetration=max_penetration,
        number_of_active_collidable_points_steady_state=nc,
        damping_ratio=damping_ratio,
    )
@jax.jit
@js.common.named_scope
def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:
    r"""
    Return the pose of the enabled collidable points.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        The stacked SE(3) matrices of all enabled collidable points.

    Note:
        Each collidable point is implicitly associated with a frame
        :math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the
        collidable point and :math:`[L]` is the orientation frame of the link it is
        rigidly attached to.
    """

    # Get the indices of the enabled collidable points.
    indices_of_enabled_collidable_points = (
        model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
    )

    parent_link_idx_of_enabled_collidable_points = jnp.array(
        model.kin_dyn_parameters.contact_parameters.body, dtype=int
    )[indices_of_enabled_collidable_points]

    # Get the transforms of the parent link of all collidable points.
    W_H_L = data._link_transforms[parent_link_idx_of_enabled_collidable_points]

    L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
        indices_of_enabled_collidable_points
    ]

    # Build the link-to-point transform from the displacement between the link frame L
    # and the implicit contact frame C (pure translation, since C has orientation [L]).
    L_H_C = jax.vmap(jnp.eye(4).at[0:3, 3].set)(L_p_Ci)

    # Compose the world-to-link and link-to-point transforms.
    # The matmul operator broadcasts over the leading (per-point) batch dimension.
    return W_H_L @ L_H_C
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def jacobian(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Array:
    r"""
    Return the free-floating Jacobian of the enabled collidable points.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian.
            If None, the active velocity representation of ``data`` is used.

    Returns:
        The stacked :math:`6 \times (6+n)` free-floating jacobians of the frames associated to the
        enabled collidable points.

    Raises:
        ValueError: If the output velocity representation is not supported.

    Note:
        Each collidable point is implicitly associated with a frame
        :math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the
        collidable point and :math:`[L]` is the orientation frame of the link it is
        rigidly attached to.
    """

    if output_vel_repr is None:
        output_vel_repr = data.velocity_representation

    contact_parameters = model.kin_dyn_parameters.contact_parameters
    enabled_points = contact_parameters.indices_of_enabled_collidable_points
    parent_link_of_point = jnp.array(contact_parameters.body, dtype=int)[enabled_points]

    # Free-floating jacobians of all the links, with inertial-fixed output.
    W_J_WL = js.model.generalized_free_floating_jacobian(
        model=model, data=data, output_vel_repr=VelRepr.Inertial
    )

    # With inertial-fixed output, the jacobian of each frame C coincides with the
    # jacobian of the link the collidable point is attached to.
    W_J_WC = W_J_WL[parent_link_of_point]

    if output_vel_repr == VelRepr.Inertial:
        return W_J_WC

    if output_vel_repr not in (VelRepr.Body, VelRepr.Mixed):
        raise ValueError(output_vel_repr)

    W_H_C = transforms(model=model, data=data)
    use_mixed = output_vel_repr == VelRepr.Mixed

    def to_output_representation(W_H_Ci: jtp.Matrix, W_J_WCi: jtp.Matrix) -> jtp.Matrix:
        # Mixed output uses the frame with the origin of C and the orientation of W.
        W_H_Oi = W_H_Ci.at[0:3, 0:3].set(jnp.eye(3)) if use_mixed else W_H_Ci
        O_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_Oi, inverse=True)
        return O_X_W @ W_J_WCi

    return jax.vmap(to_output_representation)(W_H_C, W_J_WC)
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def jacobian_derivative(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    r"""
    Compute the derivative of the free-floating jacobian of the enabled collidable points.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian derivative.
            If None, the active velocity representation of ``data`` is used.

    Returns:
        The stacked derivatives of the :math:`6 \times (6+n)` free-floating jacobians
        of the frames associated to the enabled collidable points.

    Raises:
        ValueError: If the active or the output velocity representation is not supported.

    Note:
        The input representation of the free-floating jacobian derivative is the active
        velocity representation.

    Note:
        For each point, the jacobian is treated as the product
        :math:`{}^O X_W \, {}^W J_{W,L} \, T`, where :math:`{}^W J_{W,L}` is the
        inertial-input, inertial-output jacobian of the parent link and :math:`T`
        adapts the input representation. The derivative is obtained with the
        product rule :math:`\dot{X} J T + X \dot{J} T + X J \dot{T}`.
    """

    # Fall back to the active velocity representation of data.
    output_vel_repr = (
        output_vel_repr if output_vel_repr is not None else data.velocity_representation
    )

    indices_of_enabled_collidable_points = (
        model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
    )

    # Get the index of the parent link and the position of the collidable point.
    parent_link_idx_of_enabled_collidable_points = jnp.array(
        model.kin_dyn_parameters.contact_parameters.body, dtype=int
    )[indices_of_enabled_collidable_points]

    L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
        indices_of_enabled_collidable_points
    ]

    # Get the transforms of all the parent links.
    W_H_Li = data._link_transforms

    # Get the link velocities.
    W_v_WLi = data._link_velocities

    # =====================================================
    # Compute quantities to adjust the input representation
    # =====================================================

    def compute_T(model: js.model.JaxSimModel, X: jtp.Matrix) -> jtp.Matrix:
        # Block-diagonal operator: X acts on the base velocity and the identity
        # acts on the joint velocities.
        In = jnp.eye(model.dofs())
        T = jax.scipy.linalg.block_diag(X, In)
        return T

    def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:
        # Time derivative of T: the joint block is constant, hence zero.
        On = jnp.zeros(shape=(model.dofs(), model.dofs()))
        Ṫ = jax.scipy.linalg.block_diag(Ẋ, On)
        return Ṫ

    # Compute the operator to change the representation of ν, and its
    # time derivative.
    match data.velocity_representation:
        case VelRepr.Inertial:
            # Identity change of representation: T is the identity and Ṫ is zero.
            W_H_W = jnp.eye(4)
            W_X_W = Adjoint.from_transform(transform=W_H_W)
            W_Ẋ_W = jnp.zeros((6, 6))

            T = compute_T(model=model, X=W_X_W)
            Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W)

        case VelRepr.Body:
            # The derivative of the adjoint is W_Ẋ_B = W_X_B (B_v_WB ×).
            W_H_B = data._base_transform
            W_X_B = Adjoint.from_transform(transform=W_H_B)
            B_v_WB = data.base_velocity
            B_vx_WB = Cross.vx(B_v_WB)
            W_Ẋ_B = W_X_B @ B_vx_WB

            T = compute_T(model=model, X=W_X_B)
            Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B)

        case VelRepr.Mixed:
            # BW has the origin of B and the orientation of W, so only its linear
            # velocity is kept (entries 3:6 are zeroed; presumably the angular
            # part in a [linear; angular] layout — TODO confirm).
            W_H_B = data._base_transform
            W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
            W_X_BW = Adjoint.from_transform(transform=W_H_BW)
            BW_v_WB = data.base_velocity
            BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
            BW_vx_W_BW = Cross.vx(BW_v_W_BW)
            W_Ẋ_BW = W_X_BW @ BW_vx_W_BW

            T = compute_T(model=model, X=W_X_BW)
            Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW)

        case _:
            raise ValueError(data.velocity_representation)

    # =====================================================
    # Compute quantities to adjust the output representation
    # =====================================================

    with data.switch_velocity_representation(VelRepr.Inertial):
        # Compute the Jacobian of the parent link in inertial representation.
        W_J_WL_W = js.model.generalized_free_floating_jacobian(
            model=model,
            data=data,
        )
        # Compute the Jacobian derivative of the parent link in inertial representation.
        W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative(
            model=model,
            data=data,
        )

    def compute_O_J̇_WC_I(
        L_p_C: jtp.Vector,
        parent_link_idx: jtp.Int,
        W_H_L: jtp.Matrix,
    ) -> jtp.Matrix:
        # Jacobian derivative of a single collidable point. W_H_L holds the
        # transforms of all links (it is not vmapped) and is indexed here by the
        # parent link of the point.

        match output_vel_repr:
            case VelRepr.Inertial:
                O_X_W = W_X_W = Adjoint.from_transform(  # noqa: F841
                    transform=jnp.eye(4)
                )
                O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6))  # noqa: F841

            case VelRepr.Body:
                # C is rigidly attached to its parent link, so its inertial-fixed
                # velocity is the one of the link.
                L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
                W_H_C = W_H_L[parent_link_idx] @ L_H_C
                O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
                W_v_WC = W_v_WLi[parent_link_idx]
                W_vx_WC = Cross.vx(W_v_WC)
                O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC  # noqa: F841

            case VelRepr.Mixed:
                # CW has the origin of C and the orientation of W: its velocity
                # keeps only the first three (linear) entries of the mixed
                # velocity of C, with the remaining ones set to zero.
                L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
                W_H_C = W_H_L[parent_link_idx] @ L_H_C
                W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
                CW_H_W = Transform.inverse(W_H_CW)
                O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W)
                CW_v_WC = CW_X_W @ W_v_WLi[parent_link_idx]
                W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3])
                W_vx_W_CW = Cross.vx(W_v_W_CW)
                O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW  # noqa: F841

            case _:
                raise ValueError(output_vel_repr)

        # Product rule applied to O_X_W @ W_J_WL_W @ T.
        O_J̇_WC_I = jnp.zeros(shape=(6, 6 + model.dofs()))
        O_J̇_WC_I += O_Ẋ_W @ W_J_WL_W[parent_link_idx] @ T
        O_J̇_WC_I += O_X_W @ W_J̇_WL_W[parent_link_idx] @ T
        O_J̇_WC_I += O_X_W @ W_J_WL_W[parent_link_idx] @ Ṫ

        return O_J̇_WC_I

    # Vectorize over the enabled collidable points; the link transforms are shared.
    O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, None))(
        L_p_Ci, parent_link_idx_of_enabled_collidable_points, W_H_Li
    )

    return O_J̇_WC
@jax.jit
@js.common.named_scope
def link_contact_forces(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_forces: jtp.MatrixLike | None = None,
    joint_torques: jtp.VectorLike | None = None,
) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]:
    """
    Compute the 6D contact forces of all links of the model in inertial representation.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_forces:
            The 6D external forces to apply to the links, expressed in inertial
            representation. Not forwarded to soft contact models.
        joint_torques:
            The joint torques acting on the joints. Not forwarded to soft
            contact models.

    Returns:
        A tuple whose first element is a `(nL, 6)` array containing the stacked 6D
        contact forces of the links, expressed in inertial representation, and whose
        second element is the dictionary of auxiliary data returned by the active
        contact model.
    """

    # Soft contact models are called without the external link forces and joint
    # torques; all the other contact models receive them (the joint torques are
    # passed as joint force references).
    if isinstance(model.contact_model, SoftContacts):
        contact_model_kwargs = {}
    else:
        contact_model_kwargs = dict(
            link_forces=link_forces, joint_force_references=joint_torques
        )

    # Compute the contact forces for each collidable point with the active contact model.
    W_f_C, aux_dict = model.contact_model.compute_contact_forces(
        model=model,
        data=data,
        **contact_model_kwargs,
    )

    # Compute the 6D forces applied to the links equivalent to the forces applied
    # to the frames associated to the collidable points.
    W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C)

    return W_f_L, aux_dict
def link_forces_from_contact_forces(
    model: js.model.JaxSimModel,
    *,
    contact_forces: jtp.MatrixLike,
) -> jtp.Matrix:
    """
    Compute the link forces from the contact forces.

    Args:
        model: The robot model considered by the contact model.
        contact_forces:
            The 6D contact forces of the enabled collidable points, as computed by
            the contact model and expressed in inertial-fixed representation.

    Returns:
        A `(nL, 6)` array with the 6D contact forces applied to the links, obtained
        by summing the forces of all collidable points attached to each link. Like
        the input, they are expressed in inertial-fixed representation.
    """

    # Get the object storing the contact parameters of the model.
    contact_parameters = model.kin_dyn_parameters.contact_parameters

    # Extract the indices corresponding to the enabled collidable points.
    indices_of_enabled_collidable_points = (
        contact_parameters.indices_of_enabled_collidable_points
    )

    # Convert the contact forces to a 2D JAX array with one row per point.
    W_f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())

    # Construct the vector defining the parent link index of each collidable point.
    # We use this vector to sum the 6D forces of all collidable points rigidly
    # attached to the same link.
    parent_link_index_of_collidable_points = jnp.array(
        contact_parameters.body, dtype=int
    )[indices_of_enabled_collidable_points]

    # Create the mask that associates each collidable point with its parent link.
    # We use this mask to sum the collidable points to the right link.
    mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
        model.number_of_links()
    )

    # Sum the forces of all collidable points rigidly attached to a body.
    # Since the contact forces W_f_C are expressed in the world frame,
    # we don't need any coordinate transformation.
    W_f_L = mask.T @ W_f_C

    return W_f_L
================================================
FILE: src/jaxsim/api/data.py
================================================
from __future__ import annotations
import dataclasses
import functools
from collections.abc import Sequence
try:
from typing import Self, override
except ImportError:
from typing_extensions import override, Self
import jax
import jax.numpy as jnp
import jax.scipy.spatial.transform
import jax_dataclasses
import jaxsim.api as js
import jaxsim.math
import jaxsim.rbda
import jaxsim.typing as jtp
from . import common
from .common import VelRepr
@jax_dataclasses.pytree_dataclass
class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
"""
Class storing the state of the physics model dynamics.
Attributes:
joint_positions: The vector of joint positions.
joint_velocities: The vector of joint velocities.
base_position: The 3D position of the base link.
base_quaternion: The quaternion defining the orientation of the base link.
base_linear_velocity:
The linear velocity of the base link in inertial-fixed representation.
base_angular_velocity:
The angular velocity of the base link in inertial-fixed representation.
base_transform: The base transform.
joint_transforms: The joint transforms.
link_transforms: The link transforms.
link_velocities: The link velocities in inertial-fixed representation.
"""
# Joint state
_joint_positions: jtp.Vector
_joint_velocities: jtp.Vector
# Base state
_base_quaternion: jtp.Vector
_base_linear_velocity: jtp.Vector
_base_angular_velocity: jtp.Vector
_base_position: jtp.Vector
# Cached computations.
_base_transform: jtp.Matrix = dataclasses.field(repr=False, default=None)
_joint_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
_link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
_link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)
# Extended state for soft and rigid contact models.
contact_state: dict[str, jtp.Array] = dataclasses.field(default_factory=dict)
@staticmethod
def build(
model: js.model
gitextract_5j4outfv/
├── .devcontainer/
│ ├── Dockerfile
│ └── devcontainer.json
├── .gitattributes
├── .github/
│ ├── CODEOWNERS
│ ├── dependabot.yml
│ ├── release.yml
│ └── workflows/
│ ├── ci_cd.yml
│ ├── gpu_benchmark.yml
│ ├── pixi.yml
│ └── read_the_docs.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CITATION.cff
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs/
│ ├── Makefile
│ ├── conf.py
│ ├── examples.rst
│ ├── guide/
│ │ ├── configuration.rst
│ │ └── install.rst
│ ├── index.rst
│ ├── make.bat
│ └── modules/
│ ├── api.rst
│ ├── math.rst
│ ├── mujoco.rst
│ ├── parsers.rst
│ ├── rbda.rst
│ ├── typing.rst
│ └── utils.rst
├── environment.yml
├── examples/
│ ├── .gitattributes
│ ├── .gitignore
│ ├── README.md
│ ├── assets/
│ │ ├── build_cartpole_urdf.py
│ │ └── cartpole.urdf
│ ├── jaxsim_as_multibody_dynamics_library.ipynb
│ ├── jaxsim_as_physics_engine.ipynb
│ ├── jaxsim_as_physics_engine_advanced.ipynb
│ └── jaxsim_for_robot_controllers.ipynb
├── pyproject.toml
├── src/
│ └── jaxsim/
│ ├── __init__.py
│ ├── api/
│ │ ├── __init__.py
│ │ ├── actuation_model.py
│ │ ├── com.py
│ │ ├── common.py
│ │ ├── contact.py
│ │ ├── data.py
│ │ ├── frame.py
│ │ ├── integrators.py
│ │ ├── joint.py
│ │ ├── kin_dyn_parameters.py
│ │ ├── link.py
│ │ ├── model.py
│ │ ├── ode.py
│ │ └── references.py
│ ├── exceptions.py
│ ├── logging.py
│ ├── math/
│ │ ├── __init__.py
│ │ ├── adjoint.py
│ │ ├── cross.py
│ │ ├── inertia.py
│ │ ├── joint_model.py
│ │ ├── quaternion.py
│ │ ├── rotation.py
│ │ ├── skew.py
│ │ ├── transform.py
│ │ └── utils.py
│ ├── mujoco/
│ │ ├── __init__.py
│ │ ├── __main__.py
│ │ ├── loaders.py
│ │ ├── model.py
│ │ ├── utils.py
│ │ └── visualizer.py
│ ├── parsers/
│ │ ├── __init__.py
│ │ ├── descriptions/
│ │ │ ├── __init__.py
│ │ │ ├── collision.py
│ │ │ ├── joint.py
│ │ │ ├── link.py
│ │ │ └── model.py
│ │ ├── kinematic_graph.py
│ │ └── rod/
│ │ ├── __init__.py
│ │ ├── meshes.py
│ │ ├── parser.py
│ │ └── utils.py
│ ├── rbda/
│ │ ├── __init__.py
│ │ ├── aba.py
│ │ ├── aba_parallel.py
│ │ ├── actuation/
│ │ │ ├── __init__.py
│ │ │ └── common.py
│ │ ├── collidable_points.py
│ │ ├── contacts/
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── relaxed_rigid.py
│ │ │ ├── rigid.py
│ │ │ └── soft.py
│ │ ├── crba.py
│ │ ├── forward_kinematics.py
│ │ ├── forward_kinematics_parallel.py
│ │ ├── jacobian.py
│ │ ├── kinematic_constraints.py
│ │ ├── mass_inverse.py
│ │ ├── rnea.py
│ │ └── utils.py
│ ├── terrain/
│ │ ├── __init__.py
│ │ └── terrain.py
│ ├── typing.py
│ └── utils/
│ ├── __init__.py
│ ├── jaxsim_dataclass.py
│ ├── tracing.py
│ └── wrappers.py
└── tests/
├── __init__.py
├── assets/
│ ├── 4_bar_opened.urdf
│ ├── cube.stl
│ ├── double_pendulum.sdf
│ ├── mixed_shapes_robot.urdf
│ └── test_cube.urdf
├── conftest.py
├── test_actuation.py
├── test_api_com.py
├── test_api_contact.py
├── test_api_data.py
├── test_api_frame.py
├── test_api_joint.py
├── test_api_link.py
├── test_api_model.py
├── test_api_model_hw_parametrization.py
├── test_automatic_differentiation.py
├── test_benchmark.py
├── test_exceptions.py
├── test_meshes.py
├── test_pytree.py
├── test_simulations.py
├── test_visualizer.py
└── utils.py
SYMBOL INDEX (670 symbols across 75 files)
FILE: src/jaxsim/__init__.py
function _jnp_options (line 6) | def _jnp_options() -> None:
function _np_options (line 44) | def _np_options() -> None:
function _is_editable (line 50) | def _is_editable() -> bool:
function _get_default_logging_level (line 70) | def _get_default_logging_level() -> logging.LoggingLevel:
FILE: src/jaxsim/api/actuation_model.py
function compute_resultant_torques (line 7) | def compute_resultant_torques(
function tn_curve_fn (line 101) | def tn_curve_fn(
FILE: src/jaxsim/api/com.py
function com_position (line 13) | def com_position(
function com_linear_velocity (line 50) | def com_linear_velocity(
function centroidal_momentum (line 81) | def centroidal_momentum(
function centroidal_momentum_jacobian (line 109) | def centroidal_momentum_jacobian(
function locked_centroidal_spatial_inertia (line 158) | def locked_centroidal_spatial_inertia(
function average_centroidal_velocity (line 196) | def average_centroidal_velocity(
function average_centroidal_velocity_jacobian (line 224) | def average_centroidal_velocity_jacobian(
function bias_acceleration (line 251) | def bias_acceleration(
FILE: src/jaxsim/api/common.py
function named_scope (line 28) | def named_scope(fn, name: str | None = None) -> Callable[_P, _R]:
class VelRepr (line 40) | class VelRepr(enum.IntEnum):
class ModelDataWithVelocityRepresentation (line 51) | class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
method switch_velocity_representation (line 61) | def switch_velocity_representation(
method inertial_to_other_representation (line 102) | def inertial_to_other_representation(
method other_representation_to_inertial (line 162) | def other_representation_to_inertial(
FILE: src/jaxsim/api/contact.py
function collidable_point_kinematics (line 20) | def collidable_point_kinematics(
function collidable_point_positions (line 50) | def collidable_point_positions(
function collidable_point_velocities (line 71) | def collidable_point_velocities(
function in_contact (line 92) | def in_contact(
function estimate_good_soft_contacts_parameters (line 148) | def estimate_good_soft_contacts_parameters(
function estimate_good_contact_parameters (line 160) | def estimate_good_contact_parameters(
function transforms (line 216) | def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelDat...
function jacobian (line 260) | def jacobian(
function jacobian_derivative (line 353) | def jacobian_derivative(
function link_contact_forces (line 516) | def link_contact_forces(
function link_forces_from_contact_forces (line 557) | def link_forces_from_contact_forces(
FILE: src/jaxsim/api/data.py
class JaxSimModelData (line 27) | class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
method build (line 66) | def build(
method zero (line 206) | def zero(
method joint_positions (line 229) | def joint_positions(self) -> jtp.Vector:
method joint_velocities (line 239) | def joint_velocities(self) -> jtp.Vector:
method base_quaternion (line 249) | def base_quaternion(self) -> jtp.Vector:
method base_position (line 259) | def base_position(self) -> jtp.Vector:
method base_orientation (line 269) | def base_orientation(self) -> jtp.Matrix:
method base_velocity (line 290) | def base_velocity(self) -> jtp.Vector:
method generalized_position (line 316) | def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
method generalized_velocity (line 328) | def generalized_velocity(self) -> jtp.Vector:
method base_transform (line 345) | def base_transform(self) -> jtp.Matrix:
method reset_base_quaternion (line 360) | def reset_base_quaternion(
method reset_base_pose (line 383) | def reset_base_pose(
method replace (line 407) | def replace(
method valid (line 527) | def valid(self, model: js.model.JaxSimModel) -> bool:
function random_model_data (line 555) | def random_model_data(
FILE: src/jaxsim/api/frame.py
function idx_of_parent_link (line 21) | def idx_of_parent_link(
function name_to_idx (line 51) | def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp....
function idx_to_name (line 76) | def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike...
function names_to_idxs (line 104) | def names_to_idxs(
function idxs_to_names (line 123) | def idxs_to_names(
function transform (line 147) | def transform(
function velocity (line 189) | def velocity(
function jacobian (line 240) | def jacobian(
function jacobian_derivative (line 330) | def jacobian_derivative(
FILE: src/jaxsim/api/integrators.py
function semi_implicit_euler_integration (line 14) | def semi_implicit_euler_integration(
function rk4_integration (line 90) | def rk4_integration(
function rk4fast_integration (line 158) | def rk4fast_integration(
FILE: src/jaxsim/api/joint.py
function name_to_idx (line 18) | def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp....
function idx_to_name (line 44) | def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike...
function names_to_idxs (line 67) | def names_to_idxs(
function idxs_to_names (line 86) | def idxs_to_names(
function position_limit (line 111) | def position_limit(
function position_limits (line 148) | def position_limits(
function random_joint_positions (line 184) | def random_joint_positions(
FILE: src/jaxsim/api/kin_dyn_parameters.py
class KinDynParameters (line 23) | class KinDynParameters(JaxsimDataclass):
method motion_subspaces (line 74) | def motion_subspaces(self) -> jtp.Matrix:
method parent_array (line 81) | def parent_array(self) -> jtp.Vector:
method support_body_array_bool (line 88) | def support_body_array_bool(self) -> jtp.Matrix:
method level_nodes (line 95) | def level_nodes(self) -> jtp.Matrix:
method level_mask (line 104) | def level_mask(self) -> jtp.Matrix:
method _compute_tree_levels (line 112) | def _compute_tree_levels(
method build (line 159) | def build(
method __eq__ (line 367) | def __eq__(self, other: KinDynParameters) -> bool:
method __hash__ (line 373) | def __hash__(self) -> int:
method number_of_links (line 389) | def number_of_links(self) -> int:
method number_of_joints (line 399) | def number_of_joints(self) -> int:
method number_of_frames (line 409) | def number_of_frames(self) -> int:
method support_body_array (line 419) | def support_body_array(self, link_index: jtp.IntLike) -> jtp.Vector:
method links_spatial_inertia (line 443) | def links_spatial_inertia(self) -> jtp.Array:
method tree_transforms (line 454) | def tree_transforms(self) -> jtp.Array:
method joint_transforms (line 478) | def joint_transforms(
method set_link_mass (line 538) | def set_link_mass(
method set_link_inertia (line 558) | def set_link_inertia(
class JointParameters (line 584) | class JointParameters(JaxsimDataclass):
method build_from_joint_description (line 614) | def build_from_joint_description(
class LinkParameters (line 656) | class LinkParameters(JaxsimDataclass):
method build_from_spatial_inertia (line 682) | def build_from_spatial_inertia(index: jtp.IntLike, M: jtp.Matrix) -> L...
method build_from_inertial_parameters (line 708) | def build_from_inertial_parameters(
method build_from_flat_parameters (line 735) | def build_from_flat_parameters(
method flat_parameters (line 759) | def flat_parameters(params: LinkParameters) -> jtp.Vector:
method inertia_tensor (line 783) | def inertia_tensor(params: LinkParameters) -> jtp.Matrix:
method spatial_inertia (line 799) | def spatial_inertia(params: LinkParameters) -> jtp.Matrix:
method flatten_inertia_tensor (line 817) | def flatten_inertia_tensor(I: jtp.Matrix) -> jtp.Vector:
method unflatten_inertia_tensor (line 831) | def unflatten_inertia_tensor(inertia_elements: jtp.Vector) -> jtp.Matrix:
class ContactParameters (line 847) | class ContactParameters(JaxsimDataclass):
method indices_of_enabled_collidable_points (line 874) | def indices_of_enabled_collidable_points(self) -> npt.NDArray:
method build_from (line 881) | def build_from(model_description: ModelDescription) -> ContactParameters:
class FrameParameters (line 925) | class FrameParameters(JaxsimDataclass):
method build_from (line 948) | def build_from(model_description: ModelDescription) -> FrameParameters:
class LinkParametrizableShape (line 990) | class LinkParametrizableShape:
class HwLinkMetadata (line 1003) | class HwLinkMetadata(JaxsimDataclass):
method empty (line 1042) | def empty(cls) -> HwLinkMetadata:
method compute_mesh_inertia (line 1060) | def compute_mesh_inertia(
method precompute_mesh_moments (line 1132) | def precompute_mesh_moments(vertices: np.ndarray, faces: np.ndarray) -...
method compute_mesh_inertia_from_moments (line 1178) | def compute_mesh_inertia_from_moments(
method compute_mass_and_inertia (line 1222) | def compute_mass_and_inertia(
method _convert_scaling_to_3d_vector (line 1311) | def _convert_scaling_to_3d_vector(
method compute_contact_points (line 1353) | def compute_contact_points(
method compute_inertia_link (line 1427) | def compute_inertia_link(I_com, L_H_G) -> jtp.Matrix:
method apply_scaling (line 1436) | def apply_scaling(
class ScalingFactors (line 1529) | class ScalingFactors(JaxsimDataclass):
class ConstraintType (line 1543) | class ConstraintType:
class ConstraintMap (line 1554) | class ConstraintMap(JaxsimDataclass):
method add_constraint (line 1582) | def add_constraint(
FILE: src/jaxsim/api/link.py
function name_to_idx (line 23) | def name_to_idx(model: js.model.JaxSimModel, *, link_name: str) -> jtp.Int:
function idx_to_name (line 45) | def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike)...
function names_to_idxs (line 67) | def names_to_idxs(
function idxs_to_names (line 86) | def idxs_to_names(
function mass (line 109) | def mass(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp...
function spatial_inertia (line 133) | def spatial_inertia(
function transform (line 164) | def transform(
function com_position (line 194) | def com_position(
function jacobian (line 238) | def jacobian(
function velocity (line 338) | def velocity(
function jacobian_derivative (line 388) | def jacobian_derivative(
function bias_acceleration (line 433) | def bias_acceleration(
FILE: src/jaxsim/api/model.py
class IntegratorType (line 39) | class IntegratorType(enum.IntEnum):
class JaxSimModel (line 48) | class JaxSimModel(JaxsimDataclass):
method description (line 94) | def description(self) -> ModelDescription:
method __eq__ (line 100) | def __eq__(self, other: JaxSimModel) -> bool:
method __hash__ (line 115) | def __hash__(self) -> int:
method build_from_model_description (line 130) | def build_from_model_description(
method build (line 227) | def build(
method compute_hw_link_metadata (line 333) | def compute_hw_link_metadata(
method export_updated_model (line 631) | def export_updated_model(self) -> str:
method name (line 951) | def name(self) -> str:
method number_of_links (line 961) | def number_of_links(self) -> int:
method number_of_joints (line 974) | def number_of_joints(self) -> int:
method number_of_frames (line 984) | def number_of_frames(self) -> int:
method floating_base (line 999) | def floating_base(self) -> bool:
method base_link (line 1009) | def base_link(self) -> str:
method dofs (line 1026) | def dofs(self) -> int:
method joint_names (line 1040) | def joint_names(self) -> tuple[str, ...]:
method link_names (line 1054) | def link_names(self) -> tuple[str, ...]:
method frame_names (line 1068) | def frame_names(self) -> tuple[str, ...]:
function reduce (line 1084) | def reduce(
function total_mass (line 1165) | def total_mass(model: JaxSimModel) -> jtp.Float:
function link_spatial_inertia_matrices (line 1181) | def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array:
function _adjoint_from_rotation_translation (line 1202) | def _adjoint_from_rotation_translation(
function _inverse_adjoint_from_rotation_translation (line 1217) | def _inverse_adjoint_from_rotation_translation(
function _apply_input_representation_to_jacobian (line 1233) | def _apply_input_representation_to_jacobian(
function _apply_input_representation_derivative_to_jacobian (line 1245) | def _apply_input_representation_derivative_to_jacobian(
function _link_jacobian_support_mask (line 1260) | def _link_jacobian_support_mask(
function _body_input_transform (line 1275) | def _body_input_transform(
function _link_output_adjoint_from_body (line 1309) | def _link_output_adjoint_from_body(
function generalized_free_floating_jacobian (line 1346) | def generalized_free_floating_jacobian(
function generalized_free_floating_jacobian_derivative (line 1400) | def generalized_free_floating_jacobian_derivative(
function forward_dynamics (line 1506) | def forward_dynamics(
function forward_dynamics_aba (line 1545) | def forward_dynamics_aba(
function forward_dynamics_crb (line 1696) | def forward_dynamics_crb(
function forward_kinematics (line 1788) | def forward_kinematics(
function _transform_M_block (line 1835) | def _transform_M_block(M_body: jtp.Matrix, X: jtp.Matrix) -> jtp.Matrix:
function free_floating_mass_matrix (line 1864) | def free_floating_mass_matrix(
function free_floating_mass_matrix_inverse (line 1903) | def free_floating_mass_matrix_inverse(
function free_floating_coriolis_matrix (line 1940) | def free_floating_coriolis_matrix(
function inverse_dynamics (line 2054) | def inverse_dynamics(
function free_floating_gravity_forces (line 2206) | def free_floating_gravity_forces(
function free_floating_bias_forces (line 2243) | def free_floating_bias_forces(
function locked_spatial_inertia (line 2295) | def locked_spatial_inertia(
function total_momentum (line 2314) | def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) ->...
function total_momentum_jacobian (line 2333) | def total_momentum_jacobian(
function average_velocity (line 2401) | def average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) ...
function average_velocity_jacobian (line 2421) | def average_velocity_jacobian(
function link_bias_accelerations (line 2486) | def link_bias_accelerations(
function joint_transforms (line 2697) | def joint_transforms(
function mechanical_energy (line 2727) | def mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData)...
function kinetic_energy (line 2747) | def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) ->...
function potential_energy (line 2769) | def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) ...
function update_hw_parameters (line 2793) | def update_hw_parameters(
function step (line 3039) | def step(
FILE: src/jaxsim/api/ode.py
function system_acceleration (line 16) | def system_acceleration(
function system_position_dynamics (line 136) | def system_position_dynamics(
function system_dynamics (line 176) | def system_dynamics(
FILE: src/jaxsim/api/references.py
class JaxSimModelReferences (line 23) | class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
method zero (line 36) | def zero(
method build (line 60) | def build(
method valid (line 134) | def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
method link_forces (line 168) | def link_forces(
method joint_force_references (line 250) | def joint_force_references(
method set_joint_force_references (line 306) | def set_joint_force_references(
method apply_link_forces (line 351) | def apply_link_forces(
method apply_frame_forces (line 451) | def apply_frame_forces(
FILE: src/jaxsim/exceptions.py
function raise_if (line 6) | def raise_if(
function raise_runtime_error_if (line 63) | def raise_runtime_error_if(
function raise_value_error_if (line 73) | def raise_value_error_if(
FILE: src/jaxsim/logging.py
class JaxSimWarning (line 10) | class JaxSimWarning(UserWarning):
function pretty_jaxsim_warning (line 18) | def pretty_jaxsim_warning(message, category, filename, lineno, file=None...
function jaxsim_warn (line 39) | def jaxsim_warn(msg):
class LoggingLevel (line 46) | class LoggingLevel(enum.IntEnum):
function _logger (line 55) | def _logger() -> logging.Logger:
function set_logging_level (line 59) | def set_logging_level(level: int | LoggingLevel = LoggingLevel.WARNING):
function get_logging_level (line 66) | def get_logging_level() -> LoggingLevel:
function configure (line 71) | def configure(level: LoggingLevel = LoggingLevel.WARNING) -> None:
function debug (line 86) | def debug(msg: str = "") -> None:
function info (line 90) | def info(msg: str = "") -> None:
function warning (line 94) | def warning(msg: str = "") -> None:
function error (line 98) | def error(msg: str = "") -> None:
function critical (line 102) | def critical(msg: str = "") -> None:
function exception (line 106) | def exception(msg: str = "") -> None:
FILE: src/jaxsim/math/adjoint.py
class Adjoint (line 9) | class Adjoint:
method from_quaternion_and_translation (line 15) | def from_quaternion_and_translation(
method from_transform (line 46) | def from_transform(transform: jtp.MatrixLike, inverse: bool = False) -...
method from_rotation_and_translation (line 67) | def from_rotation_and_translation(
method to_transform (line 110) | def to_transform(adjoint: jtp.Matrix) -> jtp.Matrix:
method inverse (line 136) | def inverse(adjoint: jtp.Matrix) -> jtp.Matrix:
FILE: src/jaxsim/math/cross.py
class Cross (line 8) | class Cross:
method vx (line 14) | def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix:
method vx_star (line 44) | def vx_star(velocity_sixd: jtp.Vector) -> jtp.Matrix:
FILE: src/jaxsim/math/inertia.py
class Inertia (line 8) | class Inertia:
method to_sixd (line 14) | def to_sixd(mass: jtp.Float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Ma...
method to_params (line 44) | def to_params(M: jtp.Matrix) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]:
FILE: src/jaxsim/math/joint_model.py
class JointModel (line 17) | class JointModel(JaxsimDataclass):
method build (line 46) | def build(description: ModelDescription) -> JointModel:
method parent_H_predecessor (line 115) | def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix:
method successor_H_child (line 130) | def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix:
function supported_joint_motion (line 147) | def supported_joint_motion(
FILE: src/jaxsim/math/quaternion.py
class Quaternion (line 10) | class Quaternion:
method to_xyzw (line 16) | def to_xyzw(wxyz: jtp.Vector) -> jtp.Vector:
method to_wxyz (line 29) | def to_wxyz(xyzw: jtp.Vector) -> jtp.Vector:
method to_dcm (line 42) | def to_dcm(quaternion: jtp.Vector) -> jtp.Matrix:
method from_dcm (line 55) | def from_dcm(dcm: jtp.Matrix) -> jtp.Vector:
method derivative (line 68) | def derivative(
method integration (line 135) | def integration(
FILE: src/jaxsim/math/rotation.py
class Rotation (line 10) | class Rotation:
method x (line 16) | def x(theta: jtp.Float) -> jtp.Matrix:
method y (line 30) | def y(theta: jtp.Float) -> jtp.Matrix:
method z (line 44) | def z(theta: jtp.Float) -> jtp.Matrix:
method from_axis_angle (line 58) | def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:
method log_vee (line 87) | def log_vee(R: jnp.ndarray) -> jtp.Vector:
FILE: src/jaxsim/math/skew.py
class Skew (line 6) | class Skew:
method wedge (line 12) | def wedge(vector: jtp.Vector) -> jtp.Matrix:
method vee (line 40) | def vee(matrix: jtp.Matrix) -> jtp.Vector:
FILE: src/jaxsim/math/transform.py
class Transform (line 7) | class Transform:
method from_quaternion_and_translation (line 13) | def from_quaternion_and_translation(
method from_rotation_and_translation (line 52) | def from_rotation_and_translation(
method inverse (line 84) | def inverse(transform: jtp.MatrixLike) -> jtp.Matrix:
FILE: src/jaxsim/math/utils.py
function _make_safe_norm (line 7) | def _make_safe_norm(axis, keepdims):
function safe_norm (line 45) | def safe_norm(array: jtp.ArrayLike, *, axis=None, keepdims: bool = False...
FILE: src/jaxsim/mujoco/loaders.py
function load_rod_model (line 22) | def load_rod_model(
class ModelToMjcf (line 67) | class ModelToMjcf:
method convert (line 73) | def convert(
class RodModelToMjcf (line 120) | class RodModelToMjcf:
method assets_from_rod_model (line 126) | def assets_from_rod_model(
method add_floating_joint (line 167) | def add_floating_joint(
method convert (line 223) | def convert(
class UrdfToMjcf (line 613) | class UrdfToMjcf:
method convert (line 619) | def convert(
class SdfToMjcf (line 661) | class SdfToMjcf:
method convert (line 667) | def convert(
FILE: src/jaxsim/mujoco/model.py
class MujocoModelHelper (line 19) | class MujocoModelHelper:
method __init__ (line 24) | def __init__(self, model: mj.MjModel, data: mj.MjData | None = None) -...
method build_from_xml (line 43) | def build_from_xml(
method time (line 138) | def time(self) -> float:
method timestep (line 143) | def timestep(self) -> float:
method gravity (line 148) | def gravity(self) -> npt.NDArray:
method is_floating_base (line 157) | def is_floating_base(self) -> bool:
method is_fixed_base (line 169) | def is_fixed_base(self) -> bool:
method base_link (line 174) | def base_link(self) -> str:
method base_position (line 181) | def base_position(self) -> npt.NDArray:
method base_orientation (line 190) | def base_orientation(self, dcm: bool = False) -> npt.NDArray:
method set_base_position (line 203) | def set_base_position(self, position: npt.NDArray) -> None:
method set_base_orientation (line 216) | def set_base_orientation(self, orientation: npt.NDArray, dcm: bool = F...
method number_of_joints (line 256) | def number_of_joints(self) -> int:
method number_of_dofs (line 261) | def number_of_dofs(self) -> int:
method joint_names (line 266) | def joint_names(self) -> list[str]:
method joint_dofs (line 274) | def joint_dofs(self, joint_name: str) -> int:
method joint_position (line 282) | def joint_position(self, joint_name: str) -> npt.NDArray:
method joint_positions (line 290) | def joint_positions(self, joint_names: list[str] | None = None) -> npt...
method set_joint_position (line 299) | def set_joint_position(
method set_joint_positions (line 318) | def set_joint_positions(
method number_of_bodies (line 330) | def number_of_bodies(self) -> int:
method body_names (line 335) | def body_names(self) -> list[str]:
method body_position (line 343) | def body_position(self, body_name: str) -> npt.NDArray:
method body_orientation (line 351) | def body_orientation(self, body_name: str, dcm: bool = False) -> npt.N...
method number_of_geometries (line 365) | def number_of_geometries(self) -> int:
method geometry_names (line 370) | def geometry_names(self) -> list[str]:
method geometry_position (line 378) | def geometry_position(self, geometry_name: str) -> npt.NDArray:
method geometry_orientation (line 386) | def geometry_orientation(
method _mask_qpos (line 406) | def _mask_qpos(self, joint_names: tuple[str, ...]) -> npt.NDArray:
function generate_hfield (line 447) | def generate_hfield(
FILE: src/jaxsim/mujoco/utils.py
function mujoco_data_from_jaxsim (line 14) | def mujoco_data_from_jaxsim(
class MujocoCamera (line 108) | class MujocoCamera:
method build (line 131) | def build(cls, **kwargs) -> MujocoCamera:
method build_from_target_view (line 142) | def build_from_target_view(
method asdict (line 227) | def asdict(self) -> dict[str, str]:
FILE: src/jaxsim/mujoco/visualizer.py
class MujocoVideoRecorder (line 13) | class MujocoVideoRecorder:
method __init__ (line 18) | def __init__(
method visualize_frame (line 68) | def visualize_frame(
method reset (line 152) | def reset(
method render_frame (line 172) | def render_frame(
method record_frame (line 212) | def record_frame(
method write_video (line 222) | def write_video(self, path: pathlib.Path | str, exist_ok: bool = False...
method compute_down_sampling (line 237) | def compute_down_sampling(original_fps: int, target_min_fps: int) -> int:
class MujocoVisualizer (line 261) | class MujocoVisualizer:
method __init__ (line 266) | def __init__(
method sync (line 280) | def sync(
method open_viewer (line 294) | def open_viewer(
method open (line 312) | def open(
method setup_viewer_camera (line 348) | def setup_viewer_camera(
FILE: src/jaxsim/parsers/descriptions/collision.py
class CollidablePoint (line 17) | class CollidablePoint:
method change_link (line 31) | def change_link(
method __hash__ (line 54) | def __hash__(self) -> int:
method __eq__ (line 64) | def __eq__(self, other: CollidablePoint) -> bool:
method __str__ (line 71) | def __str__(self) -> str:
class CollisionShape (line 82) | class CollisionShape(abc.ABC):
method __str__ (line 92) | def __str__(self):
class BoxCollision (line 102) | class BoxCollision(CollisionShape):
method __hash__ (line 112) | def __hash__(self) -> int:
method __eq__ (line 120) | def __eq__(self, other: BoxCollision) -> bool:
class SphereCollision (line 129) | class SphereCollision(CollisionShape):
method __hash__ (line 139) | def __hash__(self) -> int:
method __eq__ (line 147) | def __eq__(self, other: BoxCollision) -> bool:
class MeshCollision (line 156) | class MeshCollision(CollisionShape):
method __hash__ (line 166) | def __hash__(self) -> int:
method __eq__ (line 174) | def __eq__(self, other: MeshCollision) -> bool:
FILE: src/jaxsim/parsers/descriptions/joint.py
class JointType (line 16) | class JointType:
class JointGenericAxis (line 27) | class JointGenericAxis:
method __hash__ (line 35) | def __hash__(self) -> int:
method __eq__ (line 39) | def __eq__(self, other: JointGenericAxis) -> bool:
class JointDescription (line 48) | class JointDescription(JaxsimDataclass):
method __post_init__ (line 90) | def __post_init__(self) -> None:
method __eq__ (line 100) | def __eq__(self, other: JointDescription) -> bool:
method __hash__ (line 107) | def __hash__(self) -> int:
FILE: src/jaxsim/parsers/descriptions/link.py
class LinkDescription (line 16) | class LinkDescription(JaxsimDataclass):
method __hash__ (line 41) | def __hash__(self) -> int:
method __eq__ (line 57) | def __eq__(self, other: LinkDescription) -> bool:
method name_and_index (line 76) | def name_and_index(self) -> str:
method lump_with (line 86) | def lump_with(
FILE: src/jaxsim/parsers/descriptions/model.py
class ModelDescription (line 17) | class ModelDescription(KinematicGraph):
method build_model_from (line 36) | def build_model_from(
method reduce (line 157) | def reduce(self, considered_joints: Sequence[str]) -> ModelDescription:
method update_collision_shape_of_link (line 197) | def update_collision_shape_of_link(self, link_name: str, enabled: bool...
method collision_shape_of_link (line 214) | def collision_shape_of_link(self, link_name: str) -> CollisionShape:
method all_enabled_collidable_points (line 237) | def all_enabled_collidable_points(self) -> list[CollidablePoint]:
method __eq__ (line 254) | def __eq__(self, other: ModelDescription) -> bool:
method __hash__ (line 271) | def __hash__(self) -> int:
FILE: src/jaxsim/parsers/kinematic_graph.py
class RootPose (line 21) | class RootPose:
method __hash__ (line 40) | def __hash__(self) -> int:
method __eq__ (line 51) | def __eq__(self, other: RootPose) -> bool:
class KinematicGraph (line 66) | class KinematicGraph(Sequence[LinkDescription]):
method links_dict (line 99) | def links_dict(self) -> dict[str, LinkDescription]:
method frames_dict (line 106) | def frames_dict(self) -> dict[str, LinkDescription]:
method joints_dict (line 113) | def joints_dict(self) -> dict[str, JointDescription]:
method joints_connection_dict (line 120) | def joints_connection_dict(
method __post_init__ (line 128) | def __post_init__(self) -> None:
method build_from (line 174) | def build_from(
method _create_graph (line 234) | def _create_graph(
method reduce (line 379) | def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph:
method link_names (line 613) | def link_names(self) -> list[str]:
method joint_names (line 622) | def joint_names(self) -> list[str]:
method frame_names (line 631) | def frame_names(self) -> list[str]:
method print_tree (line 641) | def print_tree(self) -> None:
method joints_removed (line 658) | def joints_removed(self) -> list[JointDescription]:
method breadth_first_search (line 669) | def breadth_first_search(
method __iter__ (line 715) | def __iter__(self) -> Iterator[LinkDescription]:
method __reversed__ (line 718) | def __reversed__(self) -> Iterable[LinkDescription]:
method __len__ (line 721) | def __len__(self) -> int:
method __contains__ (line 724) | def __contains__(self, item: str | LinkDescription) -> bool:
method __getitem__ (line 733) | def __getitem__(self, key: int | str) -> LinkDescription:
method count (line 748) | def count(self, value: LinkDescription) -> int:
method index (line 754) | def index(self, value: LinkDescription, start: int = 0, stop: int = -1...
class KinematicGraphTransforms (line 767) | class KinematicGraphTransforms:
method __post_init__ (line 785) | def __post_init__(self) -> None:
method initial_joint_positions (line 793) | def initial_joint_positions(self) -> npt.NDArray:
method initial_joint_positions (line 803) | def initial_joint_positions(
method transform (line 831) | def transform(self, name: str) -> npt.NDArray:
method relative_transform (line 913) | def relative_transform(self, relative_to: str, name: str) -> npt.NDArray:
method pre_H_suc (line 935) | def pre_H_suc(
method find_parent_link_of_frame (line 958) | def find_parent_link_of_frame(self, name: str) -> str:
FILE: src/jaxsim/parsers/rod/meshes.py
function extract_points_vertices (line 7) | def extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray:
function extract_points_random_surface_sampling (line 14) | def extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n) -> ...
function extract_points_uniform_surface_sampling (line 29) | def extract_points_uniform_surface_sampling(
function extract_points_select_points_over_axis (line 46) | def extract_points_select_points_over_axis(
function extract_points_aap (line 71) | def extract_points_aap(
FILE: src/jaxsim/parsers/rod/parser.py
class SDFData (line 17) | class SDFData(NamedTuple):
function extract_model_data (line 36) | def extract_model_data(
function build_model_description (line 372) | def build_model_description(
FILE: src/jaxsim/parsers/rod/utils.py
function from_sdf_inertial (line 21) | def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
function joint_to_joint_type (line 69) | def joint_to_joint_type(joint: rod.Joint) -> int:
function create_box_collision (line 102) | def create_box_collision(
function create_sphere_collision (line 158) | def create_sphere_collision(
function create_mesh_collision (line 228) | def create_mesh_collision(
function prepare_mesh_for_parametrization (line 283) | def prepare_mesh_for_parametrization(
FILE: src/jaxsim/rbda/aba.py
function aba (line 12) | def aba(
FILE: src/jaxsim/rbda/aba_parallel.py
function aba_parallel (line 14) | def aba_parallel(
FILE: src/jaxsim/rbda/actuation/common.py
class ActuationParams (line 11) | class ActuationParams(JaxsimDataclass):
FILE: src/jaxsim/rbda/collidable_points.py
function collidable_points_pos_vel (line 9) | def collidable_points_pos_vel(
FILE: src/jaxsim/rbda/contacts/common.py
function compute_penetration_data (line 26) | def compute_penetration_data(
class ContactsParams (line 66) | class ContactsParams(JaxsimDataclass):
method build (line 79) | def build(cls: type[Self], **kwargs) -> Self:
method build_default_from_jaxsim_model (line 88) | def build_default_from_jaxsim_model(
method valid (line 171) | def valid(self, **kwargs) -> jtp.BoolLike:
class ContactModel (line 181) | class ContactModel(JaxsimDataclass):
method build (line 188) | def build(
method compute_contact_forces (line 202) | def compute_contact_forces(
method zero_state_variables (line 225) | def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str...
method _parameters_class (line 246) | def _parameters_class(self) -> type[ContactsParams]:
method update_contact_state (line 265) | def update_contact_state(
method update_velocity_after_impact (line 279) | def update_velocity_after_impact(
FILE: src/jaxsim/rbda/contacts/relaxed_rigid.py
class RelaxedRigidContactsParams (line 32) | class RelaxedRigidContactsParams(common.ContactsParams):
method __hash__ (line 85) | def __hash__(self) -> int:
method __eq__ (line 103) | def __eq__(self, other: RelaxedRigidContactsParams) -> bool:
method build (line 110) | def build(
method valid (line 170) | def valid(self) -> jtp.BoolLike:
class RelaxedRigidContacts (line 187) | class RelaxedRigidContacts(common.ContactModel):
method solver_options (line 198) | def solver_options(self) -> dict[str, Any]:
method build (line 210) | def build(
method update_contact_state (line 252) | def update_contact_state(
method update_velocity_after_impact (line 267) | def update_velocity_after_impact(
method compute_contact_forces (line 284) | def compute_contact_forces(
method _regularizers (line 526) | def _regularizers(
FILE: src/jaxsim/rbda/contacts/rigid.py
class RigidContactsParams (line 26) | class RigidContactsParams(ContactsParams):
method __hash__ (line 44) | def __hash__(self) -> int:
method __eq__ (line 55) | def __eq__(self, other: RigidContactsParams) -> bool:
method build (line 62) | def build(
method valid (line 86) | def valid(self) -> jtp.BoolLike:
class RigidContacts (line 96) | class RigidContacts(ContactModel):
method solver_options (line 111) | def solver_options(self) -> dict[str, Any]:
method build (line 123) | def build(
method compute_impact_velocity (line 177) | def compute_impact_velocity(
method compute_contact_forces (line 224) | def compute_contact_forces(
method update_velocity_after_impact (line 383) | def update_velocity_after_impact(
method update_contact_state (line 445) | def update_contact_state(
function _delassus_matrix (line 462) | def _delassus_matrix(
function _compute_ineq_constraint_matrix (line 476) | def _compute_ineq_constraint_matrix(
function _linear_acceleration_of_collidable_points (line 505) | def _linear_acceleration_of_collidable_points(
function _compute_baumgarte_stabilization_term (line 526) | def _compute_baumgarte_stabilization_term(
FILE: src/jaxsim/rbda/contacts/soft.py
class SoftContactsParams (line 25) | class SoftContactsParams(common.ContactsParams):
method __hash__ (line 48) | def __hash__(self) -> int:
method __eq__ (line 62) | def __eq__(self, other: SoftContactsParams) -> bool:
method build (line 70) | def build(
method valid (line 107) | def valid(self) -> jtp.BoolLike:
class SoftContacts (line 127) | class SoftContacts(common.ContactModel):
method build (line 131) | def build(
method zero_state_variables (line 151) | def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str...
method update_contact_state (line 164) | def update_contact_state(
method update_velocity_after_impact (line 179) | def update_velocity_after_impact(
method hunt_crossley_contact_model (line 197) | def hunt_crossley_contact_model(
method compute_contact_force (line 343) | def compute_contact_force(
method compute_contact_forces (line 392) | def compute_contact_forces(
FILE: src/jaxsim/rbda/crba.py
function crba (line 10) | def crba(
FILE: src/jaxsim/rbda/forward_kinematics.py
function forward_kinematics_model (line 11) | def forward_kinematics_model(
FILE: src/jaxsim/rbda/forward_kinematics_parallel.py
function forward_kinematics_model_parallel (line 13) | def forward_kinematics_model_parallel(
FILE: src/jaxsim/rbda/jacobian.py
function jacobian (line 12) | def jacobian(
function jacobian_full_doubly_left (line 129) | def jacobian_full_doubly_left(
function jacobian_derivative_full_doubly_left (line 222) | def jacobian_derivative_full_doubly_left(
FILE: src/jaxsim/rbda/kinematic_constraints.py
function _compute_constraint_transforms_batched (line 19) | def _compute_constraint_transforms_batched(
function _compute_constraint_jacobians_batched (line 60) | def _compute_constraint_jacobians_batched(
function _compute_constraint_baumgarte_term (line 125) | def _compute_constraint_baumgarte_term(
function compute_constraint_wrenches (line 174) | def compute_constraint_wrenches(
FILE: src/jaxsim/rbda/mass_inverse.py
function mass_inverse (line 8) | def mass_inverse(
FILE: src/jaxsim/rbda/rnea.py
function rnea (line 12) | def rnea(
FILE: src/jaxsim/rbda/utils.py
function process_inputs (line 9) | def process_inputs(
FILE: src/jaxsim/terrain/terrain.py
class Terrain (line 15) | class Terrain(abc.ABC):
method height (line 26) | def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
method normal (line 40) | def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
class FlatTerrain (line 66) | class FlatTerrain(Terrain):
method build (line 74) | def build(height: jtp.FloatLike = 0.0) -> FlatTerrain:
method height (line 87) | def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
method normal (line 101) | def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
method __hash__ (line 115) | def __hash__(self) -> int:
method __eq__ (line 119) | def __eq__(self, other: FlatTerrain) -> bool:
class PlaneTerrain (line 128) | class PlaneTerrain(FlatTerrain):
method build (line 138) | def build(height: jtp.FloatLike = 0.0, *, normal: jtp.VectorLike) -> P...
method normal (line 165) | def normal(
method height (line 181) | def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
method __hash__ (line 211) | def __hash__(self) -> int:
method __eq__ (line 224) | def __eq__(self, other: PlaneTerrain) -> bool:
FILE: src/jaxsim/utils/jaxsim_dataclass.py
class JaxsimDataclass (line 22) | class JaxsimDataclass(abc.ABC):
method editable (line 29) | def editable(self: Self, validate: bool = True) -> Iterator[Self]:
method mutable_context (line 52) | def mutable_context(
method get_leaf_shapes (line 125) | def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None]:
method get_leaf_dtypes (line 145) | def get_leaf_dtypes(tree: jtp.PyTree) -> tuple:
method get_leaf_weak_types (line 165) | def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool, ...]:
method check_compatibility (line 184) | def check_compatibility(*trees: Sequence[Any]) -> None:
method is_mutable (line 235) | def is_mutable(self, validate: bool = False) -> bool:
method mutability (line 252) | def mutability(self) -> Mutability:
method set_mutability (line 262) | def set_mutability(self, mutability: Mutability) -> None:
method mutable (line 274) | def mutable(self: Self, mutable: bool = True, validate: bool = False) ...
method copy (line 296) | def copy(self: Self) -> Self:
method replace (line 313) | def replace(self: Self, validate: bool = True, **kwargs) -> Self:
method flatten (line 336) | def flatten(self) -> jtp.Vector:
method flatten_fn (line 347) | def flatten_fn(cls: type[Self]) -> Callable[[Self], jtp.Vector]:
method unflatten_fn (line 357) | def unflatten_fn(self: Self) -> Callable[[jtp.Vector], Self]:
FILE: src/jaxsim/utils/tracing.py
function tracing (line 8) | def tracing(var: Any) -> bool | jax.Array:
function not_tracing (line 16) | def not_tracing(var: Any) -> bool | jax.Array:
FILE: src/jaxsim/utils/wrappers.py
class HashlessObject (line 16) | class HashlessObject(Generic[T]):
method get (line 27) | def get(self: HashlessObject[T]) -> T:
method __hash__ (line 33) | def __hash__(self) -> int:
method __eq__ (line 37) | def __eq__(self, other: HashlessObject[T]) -> bool:
class CustomHashedObject (line 48) | class CustomHashedObject(Generic[T]):
method get (line 57) | def get(self: CustomHashedObject[T]) -> T:
method __hash__ (line 63) | def __hash__(self) -> int:
method __eq__ (line 67) | def __eq__(self, other: CustomHashedObject[T]) -> bool:
class HashedNumpyArray (line 78) | class HashedNumpyArray:
method get (line 101) | def get(self) -> jax.Array | npt.NDArray:
method __hash__ (line 107) | def __hash__(self) -> int:
method __eq__ (line 113) | def __eq__(self, other: HashedNumpyArray) -> bool:
method hash_of_array (line 128) | def hash_of_array(
FILE: tests/conftest.py
function pytest_addoption (line 19) | def pytest_addoption(parser):
function pytest_generate_tests (line 35) | def pytest_generate_tests(metafunc):
function check_gpu_usage (line 43) | def check_gpu_usage():
function pytest_configure (line 72) | def pytest_configure(config) -> None:
function load_model_from_file (line 90) | def load_model_from_file(file_path: pathlib.Path, is_urdf=False) -> rod....
function prng_key (line 111) | def prng_key() -> jax.Array:
function velocity_representation (line 135) | def velocity_representation(request) -> jaxsim.VelRepr:
function integrator (line 154) | def integrator(request) -> str:
function batch_size (line 166) | def batch_size(request) -> int:
function build_jaxsim_model (line 186) | def build_jaxsim_model(
function jaxsim_model_box (line 208) | def jaxsim_model_box() -> js.model.JaxSimModel:
function jaxsim_model_sphere (line 247) | def jaxsim_model_sphere() -> js.model.JaxSimModel:
function ergocub_model_description_path (line 278) | def ergocub_model_description_path() -> pathlib.Path:
function jaxsim_model_ergocub (line 305) | def jaxsim_model_ergocub(
function jaxsim_model_ergocub_reduced (line 320) | def jaxsim_model_ergocub_reduced(jaxsim_model_ergocub) -> js.model.JaxSi...
function jaxsim_model_ur10 (line 354) | def jaxsim_model_ur10() -> js.model.JaxSimModel:
function jaxsim_model_single_pendulum (line 371) | def jaxsim_model_single_pendulum() -> js.model.JaxSimModel:
function jaxsim_model_garpez (line 480) | def jaxsim_model_garpez() -> js.model.JaxSimModel:
function jaxsim_model_garpez_scaled (line 493) | def jaxsim_model_garpez_scaled(request) -> js.model.JaxSimModel:
function create_scalable_garpez_model (line 516) | def create_scalable_garpez_model(
function create_model_with_missing_collision (line 710) | def create_model_with_missing_collision() -> rod.Model:
function jaxsim_model_missing_collision (line 808) | def jaxsim_model_missing_collision() -> js.model.JaxSimModel:
function jaxsim_model_double_pendulum (line 825) | def jaxsim_model_double_pendulum() -> js.model.JaxSimModel:
function jaxsim_model_cartpole (line 840) | def jaxsim_model_cartpole() -> js.model.JaxSimModel:
function jaxsim_model_4_bar_linkage (line 857) | def jaxsim_model_4_bar_linkage() -> js.model.JaxSimModel:
function get_jaxsim_model_fixture (line 877) | def get_jaxsim_model_fixture(
function jaxsim_models_all (line 923) | def jaxsim_models_all(request) -> pathlib.Path | str:
function jaxsim_models_types (line 940) | def jaxsim_models_types(request) -> pathlib.Path | str:
function jaxsim_models_no_joints (line 963) | def jaxsim_models_no_joints(request) -> pathlib.Path | str:
function jaxsim_models_floating_base (line 979) | def jaxsim_models_floating_base(request) -> pathlib.Path | str:
function jaxsim_models_fixed_base (line 994) | def jaxsim_models_fixed_base(request) -> pathlib.Path | str:
function set_jax_32bit (line 1004) | def set_jax_32bit(monkeypatch):
function jaxsim_model_box_32bit (line 1017) | def jaxsim_model_box_32bit(set_jax_32bit, request) -> js.model.JaxSimModel:
FILE: tests/test_actuation.py
function test_tn_curve (line 11) | def test_tn_curve(jaxsim_model_single_pendulum: js.model.JaxSimModel):
FILE: tests/test_api_com.py
function test_com_properties (line 10) | def test_com_properties(
FILE: tests/test_api_contact.py
function test_contact_kinematics (line 11) | def test_contact_kinematics(
function test_collidable_point_jacobians (line 73) | def test_collidable_point_jacobians(
function test_contact_jacobian_derivative (line 105) | def test_contact_jacobian_derivative(
FILE: tests/test_api_data.py
function test_data_valid (line 14) | def test_data_valid(
function test_data_switch_velocity_representation (line 24) | def test_data_switch_velocity_representation(
function test_data_change_velocity_representation (line 64) | def test_data_change_velocity_representation(
FILE: tests/test_api_frame.py
function test_frame_index (line 15) | def test_frame_index(jaxsim_models_types: js.model.JaxSimModel):
function test_frame_transforms (line 75) | def test_frame_transforms(
function test_frame_jacobians (line 129) | def test_frame_jacobians(
function test_frame_jacobian_derivative (line 182) | def test_frame_jacobian_derivative(
FILE: tests/test_api_joint.py
function test_joint_index (line 9) | def test_joint_index(
FILE: tests/test_api_link.py
function test_link_index (line 15) | def test_link_index(
function test_link_inertial_properties (line 57) | def test_link_inertial_properties(
function test_link_transforms (line 98) | def test_link_transforms(
function test_link_jacobians (line 135) | def test_link_jacobians(
function test_link_bias_acceleration (line 205) | def test_link_bias_acceleration(
function test_link_jacobian_derivative (line 289) | def test_link_jacobian_derivative(
FILE: tests/test_api_model.py
function test_model_creation_and_reduction (line 16) | def test_model_creation_and_reduction(
function test_model_properties (line 234) | def test_model_properties(
function test_model_rbda (line 278) | def test_model_rbda(
function test_model_jacobian (line 335) | def test_model_jacobian(
function test_coriolis_matrix (line 406) | def test_coriolis_matrix(
function test_model_fd_id_consistency (line 495) | def test_model_fd_id_consistency(
function test_aba_vs_aba_parallel (line 580) | def test_aba_vs_aba_parallel(
function test_fk_vs_fk_parallel (line 640) | def test_fk_vs_fk_parallel(
FILE: tests/test_api_model_hw_parametrization.py
function test_update_hw_link_parameters (line 21) | def test_update_hw_link_parameters(jaxsim_model_garpez: js.model.JaxSimM...
function test_model_scaling_against_rod (line 87) | def test_model_scaling_against_rod(
function test_update_hw_parameters_vmap (line 145) | def test_update_hw_parameters_vmap(
function test_export_updated_model (line 213) | def test_export_updated_model(
function test_hw_parameters_optimization (line 335) | def test_hw_parameters_optimization(jaxsim_model_garpez: js.model.JaxSim...
function test_hw_parameters_collision_scaling (line 410) | def test_hw_parameters_collision_scaling(
function test_unsupported_link_cases (line 504) | def test_unsupported_link_cases():
function test_export_continuous_joint_handling (line 692) | def test_export_continuous_joint_handling():
function test_export_model_with_missing_collision (line 775) | def test_export_model_with_missing_collision(
function test_export_mesh_scaling_preserves_nonzero_visual_and_joint_origins (line 870) | def test_export_mesh_scaling_preserves_nonzero_visual_and_joint_origins(
function test_mesh_shape_enum (line 994) | def test_mesh_shape_enum():
function test_mixed_shapes_metadata (line 1000) | def test_mixed_shapes_metadata():
function test_mixed_shapes_scaling (line 1040) | def test_mixed_shapes_scaling():
FILE: tests/test_automatic_differentiation.py
function get_random_data_and_references (line 31) | def get_random_data_and_references(
function test_ad_aba (line 66) | def test_ad_aba(
function test_ad_aba_parallel (line 131) | def test_ad_aba_parallel(
function test_ad_rnea (line 211) | def test_ad_rnea(
function test_ad_crba (line 272) | def test_ad_crba(
function test_ad_fk (line 304) | def test_ad_fk(
function test_ad_jacobian (line 360) | def test_ad_jacobian(
function test_ad_soft_contacts (line 399) | def test_ad_soft_contacts(
function test_ad_integration (line 451) | def test_ad_integration(
function test_ad_safe_norm (line 528) | def test_ad_safe_norm(
function test_ad_hw_parameters (line 572) | def test_ad_hw_parameters(
FILE: tests/test_benchmark.py
function vectorize_data (line 12) | def vectorize_data(model: js.model.JaxSimModel, batch_size: int):
function benchmark_test_function (line 24) | def benchmark_test_function(
function test_forward_dynamics_aba (line 39) | def test_forward_dynamics_aba(
function test_free_floating_bias_forces (line 48) | def test_free_floating_bias_forces(
function test_forward_kinematics (line 59) | def test_forward_kinematics(
function test_free_floating_mass_matrix (line 68) | def test_free_floating_mass_matrix(
function test_free_floating_jacobian (line 79) | def test_free_floating_jacobian(
function test_free_floating_jacobian_derivative (line 90) | def test_free_floating_jacobian_derivative(
function test_soft_contact_model (line 104) | def test_soft_contact_model(
function test_rigid_contact_model (line 117) | def test_rigid_contact_model(
function test_relaxed_rigid_contact_model (line 130) | def test_relaxed_rigid_contact_model(
function test_simulation_step (line 143) | def test_simulation_step(
function test_update_hw_parameters (line 156) | def test_update_hw_parameters(
function test_export_updated_model (line 185) | def test_export_updated_model(
FILE: tests/test_exceptions.py
function test_exceptions_in_jit_functions (line 13) | def test_exceptions_in_jit_functions():
FILE: tests/test_meshes.py
function test_mesh_wrapping_vertex_extraction (line 6) | def test_mesh_wrapping_vertex_extraction():
function test_mesh_wrapping_aap (line 30) | def test_mesh_wrapping_aap():
function test_mesh_wrapping_points_over_axis (line 66) | def test_mesh_wrapping_points_over_axis():
FILE: tests/test_pytree.py
function test_call_jit_compiled_function_passing_different_objects (line 13) | def test_call_jit_compiled_function_passing_different_objects(
FILE: tests/test_simulations.py
function test_box_with_external_forces (line 15) | def test_box_with_external_forces(
function test_box_with_zero_gravity (line 88) | def test_box_with_zero_gravity(
function run_simulation (line 170) | def run_simulation(
function test_simulation_with_soft_contacts (line 194) | def test_simulation_with_soft_contacts(
function test_simulation_with_rigid_contacts (line 245) | def test_simulation_with_rigid_contacts(
function test_simulation_with_relaxed_rigid_contacts (line 295) | def test_simulation_with_relaxed_rigid_contacts(
function test_joint_limits (line 347) | def test_joint_limits(
function test_simulation_with_kinematic_constraints_double_pendulum (line 411) | def test_simulation_with_kinematic_constraints_double_pendulum(
function test_simulation_with_kinematic_constraints_cartpole (line 477) | def test_simulation_with_kinematic_constraints_cartpole(
function test_simulation_with_kinematic_constraints_4_bar_linkage (line 549) | def test_simulation_with_kinematic_constraints_4_bar_linkage(
FILE: tests/test_visualizer.py
function mujoco_camera (line 9) | def mujoco_camera():
function test_urdf_loading (line 22) | def test_urdf_loading(jaxsim_model_single_pendulum, mujoco_camera):
function test_sdf_loading (line 28) | def test_sdf_loading(jaxsim_model_single_pendulum, mujoco_camera):
function test_rod_loading (line 37) | def test_rod_loading(jaxsim_model_single_pendulum, mujoco_camera):
function test_heightmap (line 44) | def test_heightmap(jaxsim_model_single_pendulum, mujoco_camera):
function test_inclined_plane (line 56) | def test_inclined_plane(jaxsim_model_single_pendulum, mujoco_camera):
FILE: tests/utils.py
function assert_allclose (line 14) | def assert_allclose(actual, desired, rtol=1e-7, atol=1e-9, err_msg=""):
function build_kindyncomputations_from_jaxsim_model (line 29) | def build_kindyncomputations_from_jaxsim_model(
function store_jaxsim_data_in_kindyncomputations (line 103) | def store_jaxsim_data_in_kindyncomputations(
class KinDynComputations (line 135) | class KinDynComputations:
method build (line 143) | def build(
method set_robot_state (line 196) | def set_robot_state(
method dofs (line 249) | def dofs(self) -> int:
method joint_names (line 253) | def joint_names(self) -> list[str]:
method link_names (line 258) | def link_names(self) -> list[str]:
method frame_names (line 264) | def frame_names(self) -> list[str]:
method joint_positions (line 271) | def joint_positions(self) -> npt.NDArray:
method joint_velocities (line 280) | def joint_velocities(self) -> npt.NDArray:
method jacobian_frame (line 289) | def jacobian_frame(self, frame_name: str) -> npt.NDArray:
method total_mass (line 301) | def total_mass(self) -> float:
method link_spatial_inertia (line 306) | def link_spatial_inertia(self, link_name: str) -> npt.NDArray:
method link_mass (line 316) | def link_mass(self, link_name: str) -> float:
method floating_base_frame (line 326) | def floating_base_frame(self) -> str:
method frame_transform (line 330) | def frame_transform(self, frame_name: str) -> npt.NDArray:
method frame_relative_transform (line 346) | def frame_relative_transform(
method frame_parent_link_name (line 366) | def frame_parent_link_name(self, frame_name: str) -> str:
method base_velocity (line 373) | def base_velocity(self) -> npt.NDArray:
method frame_velocity (line 382) | def frame_velocity(self, frame_name: str) -> npt.NDArray:
method frame_bias_acc (line 391) | def frame_bias_acc(self, frame_name: str) -> npt.NDArray:
method com_position (line 400) | def com_position(self) -> npt.NDArray:
method com_velocity (line 405) | def com_velocity(self) -> npt.NDArray:
method com_bias_acceleration (line 410) | def com_bias_acceleration(self) -> npt.NDArray:
method mass_matrix (line 414) | def mass_matrix(self) -> npt.NDArray:
method bias_forces (line 423) | def bias_forces(self) -> npt.NDArray:
method gravity_forces (line 437) | def gravity_forces(self) -> npt.NDArray:
method total_momentum (line 451) | def total_momentum(self) -> npt.NDArray:
method centroidal_momentum (line 455) | def centroidal_momentum(self) -> npt.NDArray:
method total_momentum_jacobian (line 459) | def total_momentum_jacobian(self) -> npt.NDArray:
method centroidal_momentum_jacobian (line 468) | def centroidal_momentum_jacobian(self) -> npt.NDArray:
method locked_spatial_inertia (line 477) | def locked_spatial_inertia(self) -> npt.NDArray:
method locked_centroidal_spatial_inertia (line 481) | def locked_centroidal_spatial_inertia(self) -> npt.NDArray:
method average_velocity (line 485) | def average_velocity(self) -> npt.NDArray:
method average_velocity_jacobian (line 489) | def average_velocity_jacobian(self) -> npt.NDArray:
method average_centroidal_velocity (line 498) | def average_centroidal_velocity(self) -> npt.NDArray:
method average_centroidal_velocity_jacobian (line 502) | def average_centroidal_velocity_jacobian(self) -> npt.NDArray:
Condensed preview — 136 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (1,066K chars).
[
{
"path": ".devcontainer/Dockerfile",
"chars": 643,
"preview": "# syntax=docker/dockerfile:1.4\nFROM mcr.microsoft.com/devcontainers/base:jammy\n\nARG PROJECT_NAME=jaxsim\nARG PIXI_VERSION"
},
{
"path": ".devcontainer/devcontainer.json",
"chars": 1851,
"preview": "// For format details, see https://aka.ms/devcontainer.json. For config options, see the\n// README at: https://github.co"
},
{
"path": ".gitattributes",
"chars": 75,
"preview": "# GitHub syntax highlighting\npixi.lock filter=lfs diff=lfs merge=lfs -text\n"
},
{
"path": ".github/CODEOWNERS",
"chars": 20,
"preview": "* @flferretti\n"
},
{
"path": ".github/dependabot.yml",
"chars": 362,
"preview": "version: 2\nupdates:\n\n # Check for updates to GitHub Actions every month.\n - package-ecosystem: github-actions\n dire"
},
{
"path": ".github/release.yml",
"chars": 114,
"preview": "changelog:\n exclude:\n authors:\n - dependabot[bot]\n - pre-commit-ci[bot]\n - github-actions[bot]\n"
},
{
"path": ".github/workflows/ci_cd.yml",
"chars": 3516,
"preview": "name: Python CI/CD\n\non:\n workflow_dispatch:\n push:\n pull_request:\n release:\n types:\n - published\n schedule:"
},
{
"path": ".github/workflows/gpu_benchmark.yml",
"chars": 3337,
"preview": "name: GPU Benchmarks\n\non:\n push:\n branches:\n - main\n pull_request:\n types: [opened, reopened, synchronize]\n"
},
{
"path": ".github/workflows/pixi.yml",
"chars": 1598,
"preview": "name: Pixi\n\npermissions:\n contents: write\n pull-requests: write\n\non:\n workflow_dispatch:\n schedule:\n # Execute at"
},
{
"path": ".github/workflows/read_the_docs.yml",
"chars": 305,
"preview": "name: Read the Docs PR\non:\n pull_request_target:\n types:\n - opened\n\npermissions:\n pull-requests: write\n\njobs:\n"
},
{
"path": ".gitignore",
"chars": 2043,
"preview": "# IDEs\n.idea*\n.vscode/\n\n# Matlab\n*.m~\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C ex"
},
{
"path": ".pre-commit-config.yaml",
"chars": 1205,
"preview": "ci:\n autofix_prs: false\n autoupdate_schedule: quarterly\n submodules: false\n\ndefault_language_version:\n python: pytho"
},
{
"path": ".readthedocs.yaml",
"chars": 224,
"preview": "version: \"2\"\n\nbuild:\n os: ubuntu-24.04\n tools:\n python: \"mambaforge-23.11\"\n\nconda:\n environment: environment.yml\n\n"
},
{
"path": "CITATION.cff",
"chars": 1456,
"preview": "cff-version: 1.2.0\ntitle: JaxSim\nmessage: \"If you use this software, please cite the paper.\"\ntype: software\nauthors:\n -"
},
{
"path": "CONTRIBUTING.md",
"chars": 2767,
"preview": "# Contributing to JAXsim :rocket:\n\nHello Contributor,\n\nWe're thrilled that you're considering contributing to JAXsim!\nHe"
},
{
"path": "LICENSE",
"chars": 1546,
"preview": "BSD 3-Clause License\n\nCopyright (c) 2022, Artificial and Mechanical Intelligence\nAll rights reserved.\n\nRedistribution an"
},
{
"path": "README.md",
"chars": 11298,
"preview": "# JaxSim\n\n**JaxSim** is a **differentiable physics engine** built with JAX, tailored for co-design and robotic learning "
},
{
"path": "docs/Makefile",
"chars": 506,
"preview": "SPHINXOPTS ?=\nSPHINXBUILD ?= sphinx-build\nSOURCEDIR = .\nBUILDDIR = _build\nSPHINXPROJ = JAXsim\n\n# Put it"
},
{
"path": "docs/conf.py",
"chars": 3159,
"preview": "# Configuration file for the Sphinx documentation builder.\nimport os\nimport sys\n\nif os.environ.get(\"READTHEDOCS\"):\n c"
},
{
"path": "docs/examples.rst",
"chars": 1497,
"preview": ".. _collections:\n\nExample Notebooks\n=================\n\n.. toctree::\n :glob:\n :hidden:\n :maxdepth: 1\n\n _colle"
},
{
"path": "docs/guide/configuration.rst",
"chars": 2319,
"preview": "Configuration\n=============\n\nJaxSim utilizes environment variables for application configuration. Below is a detailed ov"
},
{
"path": "docs/guide/install.rst",
"chars": 936,
"preview": "Installation\n============\n\n.. _installation:\n\nPrerequisites\n-------------\n\nJAXsim requires Python 3.11 or later.\n\nBasic "
},
{
"path": "docs/index.rst",
"chars": 4391,
"preview": "JAXsim\n#######\n\nA scalable physics engine and multibody dynamics library implemented with JAX. With JIT batteries 🔋\n\n.. "
},
{
"path": "docs/make.bat",
"chars": 780,
"preview": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sphinx-build\r\n)\r\nset SOURCEDIR=.\r\nset BU"
},
{
"path": "docs/modules/api.rst",
"chars": 1381,
"preview": "Functional API\n==============\n\n.. currentmodule:: jaxsim.api\n\n.. autosummary::\n :toctree: _autosummary\n\n model\n "
},
{
"path": "docs/modules/math.rst",
"chars": 467,
"preview": "Math\n====\n\n.. currentmodule:: jaxsim.math\n\n.. automodule:: jaxsim.math.adjoint\n :members:\n :undoc-members:\n\n.. aut"
},
{
"path": "docs/modules/mujoco.rst",
"chars": 543,
"preview": "MuJoCo Visualizer\n==================\n\nJAXsim provides a simple interface with MuJoCo's visualizer. The visualizer is\na s"
},
{
"path": "docs/modules/parsers.rst",
"chars": 279,
"preview": "Parsers\n=======\n\n.. automodule:: jaxsim.parsers.descriptions.collision\n :members:\n\n.. automodule:: jaxsim.parsers.des"
},
{
"path": "docs/modules/rbda.rst",
"chars": 815,
"preview": "Rigid Body Dynamics Algorithms\n==============================\n\nThis module provides a set of algorithms for rigid body d"
},
{
"path": "docs/modules/typing.rst",
"chars": 209,
"preview": "Typing\n======\n\n.. currentmodule:: jaxsim.typing\n\n.. autosummary::\n PyTree\n Matrix\n Bool\n Int\n Float\n V"
},
{
"path": "docs/modules/utils.rst",
"chars": 163,
"preview": "Utils\n=====\n\n.. automodule:: jaxsim.utils\n :members:\n :inherited-members:\n\n.. autoclass:: jaxsim.utils.JaxsimDatac"
},
{
"path": "environment.yml",
"chars": 1341,
"preview": "name: jaxsim\nchannels:\n - conda-forge\ndependencies:\n # ===========================\n # Dependencies from setup.cfg\n #"
},
{
"path": "examples/.gitattributes",
"chars": 62,
"preview": "# GitHub syntax highlighting\npixi.lock linguist-language=YAML\n"
},
{
"path": "examples/.gitignore",
"chars": 26,
"preview": "# pixi environments\n.pixi\n"
},
{
"path": "examples/README.md",
"chars": 2594,
"preview": "# JaxSim Examples\n\nThis folder contains Jupyter notebooks that demonstrate the practical usage of JaxSim.\n\n## Featured e"
},
{
"path": "examples/assets/build_cartpole_urdf.py",
"chars": 4852,
"preview": "import os\n\nif \"ROD_LOGGING_LEVEL\" not in os.environ:\n os.environ[\"ROD_LOGGING_LEVEL\"] = \"WARNING\"\n\nimport numpy as np"
},
{
"path": "examples/assets/cartpole.urdf",
"chars": 3433,
"preview": "<?xml version=\"1.0\" encoding=\"utf-8\"?>\n<robot name=\"cartpole\">\n <link name=\"world\"/>\n <link name=\"rail\">\n <"
},
{
"path": "examples/jaxsim_as_multibody_dynamics_library.ipynb",
"chars": 34325,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"DpLq0-lltwZ1\"\n },\n \"source\": [\n \"# `Jax"
},
{
"path": "examples/jaxsim_as_physics_engine.ipynb",
"chars": 8402,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"H-WgcgGQaTG7\"\n },\n \"source\": [\n \"# JaxS"
},
{
"path": "examples/jaxsim_as_physics_engine_advanced.ipynb",
"chars": 14964,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"H-WgcgGQaTG7\"\n },\n \"source\": [\n \"# JaxS"
},
{
"path": "examples/jaxsim_for_robot_controllers.ipynb",
"chars": 15935,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"EhPy6FgiZH4d\"\n },\n \"source\": [\n \"# JaxS"
},
{
"path": "pyproject.toml",
"chars": 6924,
"preview": "[project]\nname = \"jaxsim\"\ndynamic = [\"version\"]\nrequires-python = \">= 3.10\"\ndescription = \"A differentiable physics engi"
},
{
"path": "src/jaxsim/__init__.py",
"chars": 3628,
"preview": "from . import logging\nfrom ._version import __version__\n\n\n# Follow upstream development in https://github.com/google/jax"
},
{
"path": "src/jaxsim/api/__init__.py",
"chars": 234,
"preview": "from . import common # isort:skip\nfrom . import model, data # isort:skip\nfrom . import (\n actuation_model,\n com,"
},
{
"path": "src/jaxsim/api/actuation_model.py",
"chars": 3789,
"preview": "import jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\n\n\ndef compute_resultant_torques(\n model:"
},
{
"path": "src/jaxsim/api/com.py",
"chars": 13656,
"preview": "import jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.math\nimport jaxsim.typing as jtp\n\nfrom .common"
},
{
"path": "src/jaxsim/api/common.py",
"chars": 6942,
"preview": "import abc\nimport contextlib\nimport dataclasses\nimport enum\nimport functools\nfrom collections.abc import Callable, Itera"
},
{
"path": "src/jaxsim/api/contact.py",
"chars": 20584,
"preview": "from __future__ import annotations\n\nimport functools\n\nimport jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport"
},
{
"path": "src/jaxsim/api/data.py",
"chars": 23339,
"preview": "from __future__ import annotations\n\nimport dataclasses\nimport functools\nfrom collections.abc import Sequence\n\ntry:\n f"
},
{
"path": "src/jaxsim/api/frame.py",
"chars": 11841,
"preview": "import functools\nfrom collections.abc import Sequence\n\nimport jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimpor"
},
{
"path": "src/jaxsim/api/integrators.py",
"chars": 8882,
"preview": "import dataclasses\nfrom collections.abc import Callable\n\nimport jax\nimport jax.numpy as jnp\n\nimport jaxsim\nimport jaxsim"
},
{
"path": "src/jaxsim/api/joint.py",
"chars": 7434,
"preview": "import functools\nfrom collections.abc import Sequence\n\nimport jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimpor"
},
{
"path": "src/jaxsim/api/kin_dyn_parameters.py",
"chars": 56937,
"preview": "from __future__ import annotations\n\nimport dataclasses\nfrom itertools import starmap\nfrom typing import ClassVar\n\nimport"
},
{
"path": "src/jaxsim/api/link.py",
"chars": 12770,
"preview": "import functools\nfrom collections.abc import Sequence\n\nimport jax\nimport jax.numpy as jnp\nimport jax.scipy.linalg\nimport"
},
{
"path": "src/jaxsim/api/model.py",
"chars": 108673,
"preview": "from __future__ import annotations\n\nimport copy\nimport dataclasses\nimport enum\nimport functools\nimport pathlib\nfrom coll"
},
{
"path": "src/jaxsim/api/ode.py",
"chars": 7639,
"preview": "import jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.math import Quaterni"
},
{
"path": "src/jaxsim/api/references.py",
"chars": 20267,
"preview": "from __future__ import annotations\n\nimport functools\n\nimport jax\nimport jax.numpy as jnp\nimport jax_dataclasses\n\nimport "
},
{
"path": "src/jaxsim/exceptions.py",
"chars": 2641,
"preview": "import os\n\nimport jax\n\n\ndef raise_if(\n condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs\n) -> N"
},
{
"path": "src/jaxsim/logging.py",
"chars": 2567,
"preview": "import enum\nimport inspect\nimport logging\nimport os\nimport warnings\n\nimport coloredlogs\n\n\nclass JaxSimWarning(UserWarnin"
},
{
"path": "src/jaxsim/math/__init__.py",
"chars": 382,
"preview": "from .adjoint import Adjoint\nfrom .cross import Cross\nfrom .inertia import Inertia\nfrom .quaternion import Quaternion\nfr"
},
{
"path": "src/jaxsim/math/adjoint.py",
"chars": 4667,
"preview": "import jax.numpy as jnp\nimport jaxlie\n\nimport jaxsim.typing as jtp\n\nfrom .skew import Skew\n\n\nclass Adjoint:\n \"\"\"\n "
},
{
"path": "src/jaxsim/math/cross.py",
"chars": 1492,
"preview": "import jax.numpy as jnp\n\nimport jaxsim.typing as jtp\n\nfrom .skew import Skew\n\n\nclass Cross:\n \"\"\"\n A utility class "
},
{
"path": "src/jaxsim/math/inertia.py",
"chars": 1577,
"preview": "import jax.numpy as jnp\n\nimport jaxsim.typing as jtp\n\nfrom .skew import Skew\n\n\nclass Inertia:\n \"\"\"\n A utility clas"
},
{
"path": "src/jaxsim/math/joint_model.py",
"chars": 6980,
"preview": "from __future__ import annotations\n\nimport jax\nimport jax.numpy as jnp\nimport jax_dataclasses\nimport jaxlie\nfrom jax_dat"
},
{
"path": "src/jaxsim/math/quaternion.py",
"chars": 4531,
"preview": "import jax.lax\nimport jax.numpy as jnp\nimport jaxlie\n\nimport jaxsim.typing as jtp\n\nfrom .utils import safe_norm\n\n\nclass "
},
{
"path": "src/jaxsim/math/rotation.py",
"chars": 2220,
"preview": "import jax.numpy as jnp\nimport jaxlie\n\nimport jaxsim.typing as jtp\n\nfrom .skew import Skew\nfrom .utils import safe_norm\n"
},
{
"path": "src/jaxsim/math/skew.py",
"chars": 1402,
"preview": "import jax.numpy as jnp\n\nimport jaxsim.typing as jtp\n\n\nclass Skew:\n \"\"\"\n A utility class for skew-symmetric matrix"
},
{
"path": "src/jaxsim/math/transform.py",
"chars": 3192,
"preview": "import jax.numpy as jnp\nimport jaxlie\n\nimport jaxsim.typing as jtp\n\n\nclass Transform:\n \"\"\"\n A utility class for tr"
},
{
"path": "src/jaxsim/math/utils.py",
"chars": 1899,
"preview": "import jax\nimport jax.numpy as jnp\n\nimport jaxsim.typing as jtp\n\n\ndef _make_safe_norm(axis, keepdims):\n @jax.custom_j"
},
{
"path": "src/jaxsim/mujoco/__init__.py",
"chars": 228,
"preview": "from .loaders import ModelToMjcf, RodModelToMjcf, SdfToMjcf, UrdfToMjcf\nfrom .model import MujocoModelHelper\nfrom .utils"
},
{
"path": "src/jaxsim/mujoco/__main__.py",
"chars": 4886,
"preview": "import argparse\nimport pathlib\nimport sys\nimport time\n\nimport numpy as np\n\nfrom . import ModelToMjcf, MujocoModelHelper,"
},
{
"path": "src/jaxsim/mujoco/loaders.py",
"chars": 24214,
"preview": "import pathlib\nimport tempfile\nimport warnings\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport jaxli"
},
{
"path": "src/jaxsim/mujoco/model.py",
"chars": 16484,
"preview": "from __future__ import annotations\n\nimport functools\nimport pathlib\nfrom collections.abc import Callable, Sequence\nfrom "
},
{
"path": "src/jaxsim/mujoco/utils.py",
"chars": 8823,
"preview": "from __future__ import annotations\n\nimport dataclasses\nfrom collections.abc import Sequence\n\nimport mujoco as mj\nimport "
},
{
"path": "src/jaxsim/mujoco/visualizer.py",
"chars": 12168,
"preview": "import contextlib\nimport pathlib\nfrom collections.abc import Iterator, Sequence\n\nimport mediapy as media\nimport mujoco a"
},
{
"path": "src/jaxsim/parsers/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/jaxsim/parsers/descriptions/__init__.py",
"chars": 261,
"preview": "from .collision import (\n BoxCollision,\n CollidablePoint,\n CollisionShape,\n MeshCollision,\n SphereCollisi"
},
{
"path": "src/jaxsim/parsers/descriptions/collision.py",
"chars": 4577,
"preview": "from __future__ import annotations\n\nimport abc\nimport dataclasses\n\nimport jax.numpy as jnp\nimport numpy as np\nimport num"
},
{
"path": "src/jaxsim/parsers/descriptions/joint.py",
"chars": 4232,
"preview": "from __future__ import annotations\n\nimport dataclasses\nfrom typing import ClassVar\n\nimport jax_dataclasses\nimport numpy "
},
{
"path": "src/jaxsim/parsers/descriptions/link.py",
"chars": 3495,
"preview": "from __future__ import annotations\n\nimport dataclasses\n\nimport jax.numpy as jnp\nimport jax_dataclasses\nimport numpy as n"
},
{
"path": "src/jaxsim/parsers/descriptions/model.py",
"chars": 10152,
"preview": "from __future__ import annotations\n\nimport dataclasses\nimport itertools\nfrom collections.abc import Sequence\n\nfrom jaxsi"
},
{
"path": "src/jaxsim/parsers/kinematic_graph.py",
"chars": 35734,
"preview": "from __future__ import annotations\n\nimport copy\nimport dataclasses\nimport functools\nfrom collections.abc import Callable"
},
{
"path": "src/jaxsim/parsers/rod/__init__.py",
"chars": 92,
"preview": "from . import parser, utils\nfrom .parser import build_model_description, extract_model_data\n"
},
{
"path": "src/jaxsim/parsers/rod/meshes.py",
"chars": 2842,
"preview": "import numpy as np\nimport trimesh\n\nVALID_AXIS = {\"x\": 0, \"y\": 1, \"z\": 2}\n\n\ndef extract_points_vertices(mesh: trimesh.Tri"
},
{
"path": "src/jaxsim/parsers/rod/parser.py",
"chars": 15037,
"preview": "import dataclasses\nimport os\nimport pathlib\nfrom typing import NamedTuple\n\nimport jax.numpy as jnp\nimport numpy as np\nim"
},
{
"path": "src/jaxsim/parsers/rod/utils.py",
"chars": 10238,
"preview": "import os\nimport pathlib\nfrom collections.abc import Callable\nfrom typing import TypeVar\n\nimport numpy as np\nimport nump"
},
{
"path": "src/jaxsim/rbda/__init__.py",
"chars": 544,
"preview": "from . import actuation, contacts\nfrom .aba import aba\nfrom .aba_parallel import aba_parallel\nfrom .collidable_points im"
},
{
"path": "src/jaxsim/rbda/aba.py",
"chars": 8861,
"preview": "import jax\nimport jax.numpy as jnp\nimport jaxlie\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.math i"
},
{
"path": "src/jaxsim/rbda/aba_parallel.py",
"chars": 10234,
"preview": "import math\n\nimport jax\nimport jax.numpy as jnp\nimport jaxlie\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom "
},
{
"path": "src/jaxsim/rbda/actuation/__init__.py",
"chars": 36,
"preview": "from .common import ActuationParams\n"
},
{
"path": "src/jaxsim/rbda/actuation/common.py",
"chars": 565,
"preview": "import dataclasses\n\nimport jax_dataclasses\nfrom jax_dataclasses import Static\n\nimport jaxsim.typing as jtp\nfrom jaxsim.u"
},
{
"path": "src/jaxsim/rbda/collidable_points.py",
"chars": 2070,
"preview": "import jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.math import Skew\n\n\nd"
},
{
"path": "src/jaxsim/rbda/contacts/__init__.py",
"chars": 371,
"preview": "from . import relaxed_rigid, rigid, soft\nfrom .common import ContactModel, ContactsParams\nfrom .relaxed_rigid import Rel"
},
{
"path": "src/jaxsim/rbda/contacts/common.py",
"chars": 9185,
"preview": "from __future__ import annotations\n\nimport abc\nimport functools\n\nimport jax\nimport jax.numpy as jnp\n\nimport jaxsim.api a"
},
{
"path": "src/jaxsim/rbda/contacts/relaxed_rigid.py",
"chars": 21763,
"preview": "from __future__ import annotations\n\nimport dataclasses\nfrom collections.abc import Callable\nfrom typing import Any\n\nimpo"
},
{
"path": "src/jaxsim/rbda/contacts/rigid.py",
"chars": 17554,
"preview": "from __future__ import annotations\n\nimport dataclasses\nfrom typing import Any\n\nimport jax\nimport jax.numpy as jnp\nimport"
},
{
"path": "src/jaxsim/rbda/contacts/soft.py",
"chars": 15005,
"preview": "from __future__ import annotations\n\nimport dataclasses\nimport functools\n\nimport jax\nimport jax.numpy as jnp\nimport jax_d"
},
{
"path": "src/jaxsim/rbda/crba.py",
"chars": 5083,
"preview": "import jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\n\nfrom . import utils\n\n\ndef crba("
},
{
"path": "src/jaxsim/rbda/forward_kinematics.py",
"chars": 3530,
"preview": "import jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.math import Adjoint\n"
},
{
"path": "src/jaxsim/rbda/forward_kinematics_parallel.py",
"chars": 3430,
"preview": "import math\n\nimport jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.math im"
},
{
"path": "src/jaxsim/rbda/jacobian.py",
"chars": 11058,
"preview": "import jax\nimport jax.numpy as jnp\nimport numpy as np\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.m"
},
{
"path": "src/jaxsim/rbda/kinematic_constraints.py",
"chars": 12130,
"preview": "from __future__ import annotations\n\nimport jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as "
},
{
"path": "src/jaxsim/rbda/mass_inverse.py",
"chars": 5598,
"preview": "import jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\n\n\ndef mass_inverse(\n model: j"
},
{
"path": "src/jaxsim/rbda/rnea.py",
"chars": 7660,
"preview": "import jax\nimport jax.numpy as jnp\nimport jaxlie\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.math i"
},
{
"path": "src/jaxsim/rbda/utils.py",
"chars": 5552,
"preview": "import jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim import exceptions\nfrom jaxsim.m"
},
{
"path": "src/jaxsim/terrain/__init__.py",
"chars": 78,
"preview": "from . import terrain\nfrom .terrain import FlatTerrain, PlaneTerrain, Terrain\n"
},
{
"path": "src/jaxsim/terrain/terrain.py",
"chars": 6725,
"preview": "from __future__ import annotations\n\nimport abc\nimport dataclasses\n\nimport jax.numpy as jnp\nimport jax_dataclasses\nimport"
},
{
"path": "src/jaxsim/typing.py",
"chars": 761,
"preview": "from collections.abc import Hashable\nfrom typing import Any, TypeVar\n\nimport jax\n\n# =========\n# JAX types\n# =========\n\nA"
},
{
"path": "src/jaxsim/utils/__init__.py",
"chars": 215,
"preview": "from jax_dataclasses._copy_and_mutate import _Mutability as Mutability\n\nfrom .jaxsim_dataclass import JaxsimDataclass\nfr"
},
{
"path": "src/jaxsim/utils/jaxsim_dataclass.py",
"chars": 11373,
"preview": "import abc\nimport contextlib\nimport dataclasses\nimport functools\nfrom collections.abc import Callable, Iterator, Sequenc"
},
{
"path": "src/jaxsim/utils/tracing.py",
"chars": 530,
"preview": "from typing import Any\n\nimport jax._src.core\nimport jax.flatten_util\nimport jax.interpreters.partial_eval\n\n\ndef tracing("
},
{
"path": "src/jaxsim/utils/wrappers.py",
"chars": 4175,
"preview": "from __future__ import annotations\n\nimport dataclasses\nfrom collections.abc import Callable\nfrom typing import Generic, "
},
{
"path": "tests/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "tests/assets/4_bar_opened.urdf",
"chars": 4109,
"preview": "<robot name=\"4_bar_opened\">\n <!-- Link AB -->\n <link name=\"AB\">\n <inertial>\n <mass value=\"1.0\" /"
},
{
"path": "tests/assets/cube.stl",
"chars": 1521,
"preview": "solid model\r\nfacet normal 0.0 0.0 -1.0\r\nouter loop\r\nvertex 20.0 0.0 0.0\r\nvertex 0.0 -20.0 0.0\r\nvertex 0.0 0.0 0.0\r\nendlo"
},
{
"path": "tests/assets/double_pendulum.sdf",
"chars": 5999,
"preview": "<?xml version=\"1.0\"?>\n\n<sdf version=\"1.7\">\n\n <model name=\"double_pendulum\">\n <!-- <pose>0 0 0.2 0 0 0</pose> -"
},
{
"path": "tests/assets/mixed_shapes_robot.urdf",
"chars": 2927,
"preview": "<?xml version=\"1.0\"?>\n<robot name=\"mixed_shapes_robot\">\n\n <!-- Link 1: Box primitive -->\n <link name=\"box_link\">\n <"
},
{
"path": "tests/assets/test_cube.urdf",
"chars": 1076,
"preview": "<?xml version=\"1.0\"?>\n<robot name=\"test_cube\">\n\n <!-- Single cube link with mesh visual -->\n <link name=\"cube_link\">\n "
},
{
"path": "tests/conftest.py",
"chars": 27624,
"preview": "import os\n\nos.environ[\"JAXSIM_ENABLE_EXCEPTIONS\"] = \"1\"\n\nimport pathlib\nimport subprocess\n\nimport jax\nimport numpy as np"
},
{
"path": "tests/test_actuation.py",
"chars": 1464,
"preview": "import jax.numpy as jnp\nfrom numpy.testing import assert_array_less\n\nimport jaxsim.api as js\nimport jaxsim.rbda\nfrom jax"
},
{
"path": "tests/test_api_com.py",
"chars": 2425,
"preview": "import jax\n\nimport jaxsim.api as js\nfrom jaxsim import VelRepr\n\nfrom . import utils\nfrom .utils import assert_allclose\n\n"
},
{
"path": "tests/test_api_contact.py",
"chars": 6609,
"preview": "import jax\nimport jax.numpy as jnp\nimport rod\n\nimport jaxsim.api as js\nfrom jaxsim import VelRepr\n\nfrom .utils import as"
},
{
"path": "tests/test_api_data.py",
"chars": 3620,
"preview": "import jax\nimport jax.numpy as jnp\nimport pytest\nfrom numpy.testing import assert_raises\n\nimport jaxsim.api as js\nfrom j"
},
{
"path": "tests/test_api_frame.py",
"chars": 9014,
"preview": "import jax\nimport jax.numpy as jnp\nimport pytest\nfrom jax.errors import JaxRuntimeError\nfrom numpy.testing import assert"
},
{
"path": "tests/test_api_joint.py",
"chars": 1372,
"preview": "import jax.numpy as jnp\nimport pytest\nfrom jax.errors import JaxRuntimeError\nfrom numpy.testing import assert_array_equa"
},
{
"path": "tests/test_api_link.py",
"chars": 11201,
"preview": "import jax\nimport jax.numpy as jnp\nimport pytest\nfrom jax.errors import JaxRuntimeError\nfrom numpy.testing import assert"
},
{
"path": "tests/test_api_model.py",
"chars": 21619,
"preview": "import pathlib\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport rod\n\nimport jaxsim.api as js\nimport jaxsim."
},
{
"path": "tests/test_api_model_hw_parametrization.py",
"chars": 40086,
"preview": "import pathlib\nimport xml.etree.ElementTree as ET\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport pytest\ni"
},
{
"path": "tests/test_automatic_differentiation.py",
"chars": 18070,
"preview": "import os\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom jax.test_util import check_grads\n\nimport jaxsim.ap"
},
{
"path": "tests/test_benchmark.py",
"chars": 6992,
"preview": "from collections.abc import Callable\n\nimport jax\nimport jax.numpy as jnp\nimport pytest\n\nimport jaxsim\nimport jaxsim.api "
},
{
"path": "tests/test_exceptions.py",
"chars": 2371,
"preview": "import io\nfrom contextlib import redirect_stdout\n\nimport chex\nimport jax\nimport jax.numpy as jnp\nimport pytest\nfrom jax."
},
{
"path": "tests/test_meshes.py",
"chars": 3627,
"preview": "import trimesh\n\nfrom jaxsim.parsers.rod import meshes\n\n\ndef test_mesh_wrapping_vertex_extraction():\n \"\"\"\n Test the"
},
{
"path": "tests/test_pytree.py",
"chars": 2806,
"preview": "import io\nimport pathlib\nfrom contextlib import redirect_stdout\n\nimport chex\nimport jax\nimport jax.numpy as jnp\nimport p"
},
{
"path": "tests/test_simulations.py",
"chars": 19210,
"preview": "import jax\nimport jax.numpy as jnp\nimport numpy as np\nimport pytest\n\nimport jaxsim.api as js\nimport jaxsim.rbda\nimport j"
},
{
"path": "tests/test_visualizer.py",
"chars": 1593,
"preview": "import pytest\nimport rod\n\nfrom jaxsim.mujoco import ModelToMjcf\nfrom jaxsim.mujoco.loaders import MujocoCamera\n\n\n@pytest"
},
{
"path": "tests/utils.py",
"chars": 16236,
"preview": "from __future__ import annotations\n\nimport dataclasses\nimport pathlib\n\nimport idyntree.bindings as idt\nimport numpy as n"
}
]
About this extraction
This page contains the full source code of the ami-iit/jaxsim GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 136 files (988.2 KB), approximately 255.9k tokens, and a symbol index with 670 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.