Repository: ami-iit/jaxsim Branch: main Commit: 62ac1bc2707f Files: 136 Total size: 988.2 KB Directory structure: gitextract_5j4outfv/ ├── .devcontainer/ │ ├── Dockerfile │ └── devcontainer.json ├── .gitattributes ├── .github/ │ ├── CODEOWNERS │ ├── dependabot.yml │ ├── release.yml │ └── workflows/ │ ├── ci_cd.yml │ ├── gpu_benchmark.yml │ ├── pixi.yml │ └── read_the_docs.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs/ │ ├── Makefile │ ├── conf.py │ ├── examples.rst │ ├── guide/ │ │ ├── configuration.rst │ │ └── install.rst │ ├── index.rst │ ├── make.bat │ └── modules/ │ ├── api.rst │ ├── math.rst │ ├── mujoco.rst │ ├── parsers.rst │ ├── rbda.rst │ ├── typing.rst │ └── utils.rst ├── environment.yml ├── examples/ │ ├── .gitattributes │ ├── .gitignore │ ├── README.md │ ├── assets/ │ │ ├── build_cartpole_urdf.py │ │ └── cartpole.urdf │ ├── jaxsim_as_multibody_dynamics_library.ipynb │ ├── jaxsim_as_physics_engine.ipynb │ ├── jaxsim_as_physics_engine_advanced.ipynb │ └── jaxsim_for_robot_controllers.ipynb ├── pyproject.toml ├── src/ │ └── jaxsim/ │ ├── __init__.py │ ├── api/ │ │ ├── __init__.py │ │ ├── actuation_model.py │ │ ├── com.py │ │ ├── common.py │ │ ├── contact.py │ │ ├── data.py │ │ ├── frame.py │ │ ├── integrators.py │ │ ├── joint.py │ │ ├── kin_dyn_parameters.py │ │ ├── link.py │ │ ├── model.py │ │ ├── ode.py │ │ └── references.py │ ├── exceptions.py │ ├── logging.py │ ├── math/ │ │ ├── __init__.py │ │ ├── adjoint.py │ │ ├── cross.py │ │ ├── inertia.py │ │ ├── joint_model.py │ │ ├── quaternion.py │ │ ├── rotation.py │ │ ├── skew.py │ │ ├── transform.py │ │ └── utils.py │ ├── mujoco/ │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── loaders.py │ │ ├── model.py │ │ ├── utils.py │ │ └── visualizer.py │ ├── parsers/ │ │ ├── __init__.py │ │ ├── descriptions/ │ │ │ ├── __init__.py │ │ │ ├── collision.py │ │ │ ├── joint.py │ │ │ ├── link.py │ │ │ └── model.py │ │ ├── 
kinematic_graph.py │ │ └── rod/ │ │ ├── __init__.py │ │ ├── meshes.py │ │ ├── parser.py │ │ └── utils.py │ ├── rbda/ │ │ ├── __init__.py │ │ ├── aba.py │ │ ├── aba_parallel.py │ │ ├── actuation/ │ │ │ ├── __init__.py │ │ │ └── common.py │ │ ├── collidable_points.py │ │ ├── contacts/ │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── relaxed_rigid.py │ │ │ ├── rigid.py │ │ │ └── soft.py │ │ ├── crba.py │ │ ├── forward_kinematics.py │ │ ├── forward_kinematics_parallel.py │ │ ├── jacobian.py │ │ ├── kinematic_constraints.py │ │ ├── mass_inverse.py │ │ ├── rnea.py │ │ └── utils.py │ ├── terrain/ │ │ ├── __init__.py │ │ └── terrain.py │ ├── typing.py │ └── utils/ │ ├── __init__.py │ ├── jaxsim_dataclass.py │ ├── tracing.py │ └── wrappers.py └── tests/ ├── __init__.py ├── assets/ │ ├── 4_bar_opened.urdf │ ├── cube.stl │ ├── double_pendulum.sdf │ ├── mixed_shapes_robot.urdf │ └── test_cube.urdf ├── conftest.py ├── test_actuation.py ├── test_api_com.py ├── test_api_contact.py ├── test_api_data.py ├── test_api_frame.py ├── test_api_joint.py ├── test_api_link.py ├── test_api_model.py ├── test_api_model_hw_parametrization.py ├── test_automatic_differentiation.py ├── test_benchmark.py ├── test_exceptions.py ├── test_meshes.py ├── test_pytree.py ├── test_simulations.py ├── test_visualizer.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .devcontainer/Dockerfile ================================================ # syntax=docker/dockerfile:1.4 FROM mcr.microsoft.com/devcontainers/base:jammy ARG PROJECT_NAME=jaxsim ARG PIXI_VERSION=v0.35.0 RUN curl -o /usr/local/bin/pixi -SL https://github.com/prefix-dev/pixi/releases/download/${PIXI_VERSION}/pixi-$(uname -m)-unknown-linux-musl \ && chmod +x /usr/local/bin/pixi \ && pixi info # Add LFS repository and install. 
RUN apt-get update && apt-get install -y curl \ && curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash \ && apt install -y git-lfs USER vscode WORKDIR /home/vscode RUN echo 'eval "$(pixi completion -s bash)"' >> /home/vscode/.bashrc ================================================ FILE: .devcontainer/devcontainer.json ================================================ // For format details, see https://aka.ms/devcontainer.json. For config options, see the // README at: https://github.com/devcontainers/templates/tree/main/src/ubuntu { "name": "Ubuntu", // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile "build": { "context": "..", "dockerfile": "Dockerfile" }, // Features to add to the dev container. More info: https://containers.dev/features. "features": { "ghcr.io/devcontainers/features/docker-in-docker:2": {} }, // Put the `.pixi` folder in a mounted volume so that it lives on a case-sensitive filesystem (the bind-mounted workspace may be case-insensitive, e.g. on macOS and Windows hosts). "mounts": ["source=${localWorkspaceFolderBasename}-pixi,target=${containerWorkspaceFolder}/.pixi,type=volume"], // Use 'forwardPorts' to make a list of ports inside the container available locally. // "forwardPorts": [], // Use 'postCreateCommand' to run commands after the container is created. "postCreateCommand": "sudo chown vscode .pixi && git lfs pull --include='pixi.lock' && pixi install --environment=test-cpu", // Configure tool-specific properties. // "customizations": {}, // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root" // VSCode extensions "customizations": { "vscode": { "settings": { "python.pythonPath": "/workspaces/jaxsim/.pixi/envs/test-cpu/bin/python", "python.defaultInterpreterPath": "/workspaces/jaxsim/.pixi/envs/test-cpu/bin/python", "python.terminal.activateEnvironment": true, "python.terminal.activateEnvInCurrentTerminal": true }, "extensions": [ "ms-python.python", "donjayamanne.python-extension-pack", "ms-toolsai.jupyter", "GitHub.codespaces", "GitHub.copilot", "ms-azuretools.vscode-docker", "charliermarsh.ruff" ] } } } ================================================ FILE: .gitattributes ================================================ # GitHub syntax highlighting pixi.lock filter=lfs diff=lfs merge=lfs -text ================================================ FILE: .github/CODEOWNERS ================================================ * @flferretti ================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: # Check for updates to GitHub Actions every month. - package-ecosystem: github-actions directory: / schedule: interval: monthly # Disable rebasing automatically existing pull requests. rebase-strategy: "disabled" # Group updates to a single PR. groups: dependencies: patterns: - '*' ================================================ FILE: .github/release.yml ================================================ changelog: exclude: authors: - dependabot[bot] - pre-commit-ci[bot] - github-actions[bot] ================================================ FILE: .github/workflows/ci_cd.yml ================================================ name: Python CI/CD on: workflow_dispatch: push: pull_request: release: types: - published schedule: # Execute a nightly build at 2am UTC. 
- cron: '0 2 * * *' jobs: package: name: Package the project runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - name: Set up Python uses: actions/setup-python@v6 with: python-version: "3.11" - name: Install Python tools run: pip install build twine - name: Create distributions run: python -m build -o dist/ - name: Inspect dist folder run: ls -lah dist/ - name: Check wheel's abi and platform tags run: test $(find dist/ -name *-none-any.whl | wc -l) -gt 0 - name: Run twine check run: twine check dist/* - name: Upload artifacts uses: actions/upload-artifact@v7 with: path: dist/* name: dist test: name: 'Python${{ matrix.python }}@${{ matrix.os }}' needs: package runs-on: ${{ matrix.os }} env: PYTHONUTF8: "1" strategy: fail-fast: false matrix: os: - ubuntu-latest - macos-latest - windows-latest python: - "3.10" - "3.11" - "3.12" - "3.13" steps: - name: Set up Python uses: actions/setup-python@v6 with: python-version: ${{ matrix.python }} - name: Download Python packages uses: actions/download-artifact@v8 with: path: dist name: dist - name: Install wheel (ubuntu) if: contains(matrix.os, 'ubuntu') shell: bash run: pip install "$(find dist/ -type f -name '*.whl')" - name: Install wheel (macos|windows) if: contains(matrix.os, 'macos') || contains(matrix.os, 'windows') shell: bash run: pip install "$(find dist/ -type f -name '*.whl')" - name: Document installed pip packages shell: bash run: pip list --verbose - name: Import the package run: python -c "import jaxsim" - uses: actions/checkout@v6 with: lfs: true - uses: prefix-dev/setup-pixi@v0.9.5 if: contains(matrix.os, 'ubuntu') with: pixi-version: "latest" frozen: true cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} - name: Run the Python tests if: | contains(matrix.os, 'ubuntu') && (github.event_name != 'pull_request') run: pixi run --frozen test --numprocesses auto env: # https://github.com/pytest-dev/pytest/issues/7443#issuecomment-656642591 
PY_COLORS: "1" JAX_PLATFORM_NAME: cpu publish: name: Publish to PyPI needs: test runs-on: ubuntu-latest permissions: id-token: write steps: - name: Download Python packages uses: actions/download-artifact@v8 with: path: dist name: dist - name: Inspect dist folder run: ls -lah dist/ - name: Publish to PyPI if: | github.repository == 'gbionics/jaxsim' && ((github.event_name == 'push' && github.ref == 'refs/heads/main') || (github.event_name == 'release')) uses: pypa/gh-action-pypi-publish@release/v1 with: skip-existing: true ================================================ FILE: .github/workflows/gpu_benchmark.yml ================================================ name: GPU Benchmarks on: push: branches: - main pull_request: types: [opened, reopened, synchronize] workflow_dispatch: schedule: - cron: "0 0 * * 1" # Run At 00:00 on Monday permissions: pull-requests: write deployments: write contents: write jobs: benchmark: runs-on: self-hosted container: image: ghcr.io/prefix-dev/pixi:0.46.0-noble@sha256:c12bcbe8ba5dfd71867495d3471b95a6993b79cc7de7eafec016f8f59e4e4961 options: --rm --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -e "TERM=xterm-256color" steps: - name: Install Git and Git-LFS run: | apt update && apt install -y git git-lfs - name: Checkout repository uses: actions/checkout@v6 with: lfs: true fetch-depth: 0 - name: Fetch pixi.lock from LFS run: | git config --global safe.directory /__w/jaxsim/jaxsim git lfs checkout pixi.lock - name: Get main branch SHA id: get-main-branch-sha run: | SHA=$(git rev-parse origin/main) echo "sha=$SHA" >> $GITHUB_OUTPUT - name: Get benchmark results from main branch id: cache uses: actions/cache/restore@v5 with: path: ./cache key: ${{ runner.os }}-benchmark - name: Run benchmark and store result run: | pixi run --frozen --environment gpu benchmark --batch-size 128 --benchmark-json output.json env: PY_COLORS: "1" - name: Compare benchmark results with main branch uses: 
benchmark-action/github-action-benchmark@v1.22.0 with: tool: 'pytest' output-file-path: output.json external-data-json-path: ./cache/benchmark-data.json save-data-file: false fail-on-alert: true summary-always: true comment-always: true alert-threshold: 150% github-token: ${{ secrets.GITHUB_TOKEN }} - name: Store benchmark result for main branch uses: benchmark-action/github-action-benchmark@v1.22.0 if: ${{ github.ref_name == 'main' }} with: tool: 'pytest' output-file-path: output.json external-data-json-path: ./cache/benchmark-data.json save-data-file: true fail-on-alert: false summary-always: true comment-always: true alert-threshold: 150% github-token: ${{ secrets.GITHUB_TOKEN }} - name: Publish Benchmark Results to GitHub Pages uses: benchmark-action/github-action-benchmark@v1.22.0 if: ${{ github.ref_name == 'main' }} with: tool: 'pytest' output-file-path: output.json benchmark-data-dir-path: "benchmarks" fail-on-alert: false github-token: ${{ secrets.GITHUB_TOKEN }} comment-on-alert: true summary-always: true save-data-file: true alert-threshold: "150%" auto-push: ${{ github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' }} - name: Update Benchmark Results cache uses: actions/cache/save@v5 if: ${{ github.ref_name == 'main' }} with: path: ./cache key: ${{ runner.os }}-benchmark ================================================ FILE: .github/workflows/pixi.yml ================================================ name: Pixi permissions: contents: write pull-requests: write on: workflow_dispatch: schedule: # Execute at 5am UTC on the first day of the month. 
- cron: '0 5 1 * *' jobs: pixi-update: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v6 with: lfs: true - name: Set up pixi uses: prefix-dev/setup-pixi@v0.9.5 with: run-install: false - name: Install pixi-diff-to-markdown run: pixi global install pixi-diff-to-markdown - name: Update pixi lockfile and generate diff run: | set -o pipefail pixi update --json | pixi exec pixi-diff-to-markdown --explicit-column > diff.md - name: Test project against updated pixi run: pixi run --environment default test env: PY_COLORS: "1" JAX_PLATFORM_NAME: cpu - name: Commit and push changes run: echo "BRANCH_NAME=update-pixi-$(date +'%Y%m%d%H%M%S')" >> $GITHUB_ENV - name: Create pull request uses: peter-evans/create-pull-request@v8 with: token: ${{ secrets.GITHUB_TOKEN }} commit-message: Update `pixi.lock` title: Update `pixi` lockfile body-path: diff.md branch: ${{ env.BRANCH_NAME }} base: main labels: pixi add-paths: pixi.lock delete-branch: true committer: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> ================================================ FILE: .github/workflows/read_the_docs.yml ================================================ name: Read the Docs PR on: pull_request_target: types: - opened permissions: pull-requests: write jobs: documentation-links: runs-on: ubuntu-latest steps: - uses: readthedocs/actions/preview@v1 with: project-slug: "jaxsim" project-language: "" ================================================ FILE: .gitignore ================================================ # IDEs .idea* .vscode/ # Matlab *.m~ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are 
written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ docs/_collections/ docs/modules/_autosummary/ docs/modules/generated docs/sg_execution_times.rst # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # dynamic version src/jaxsim/_version.py # ruff .ruff_cache/ # pixi environments .pixi # data .mp4 .png ================================================ FILE: .pre-commit-config.yaml ================================================ ci: autofix_prs: false autoupdate_schedule: quarterly submodules: false default_language_version: python: python3 repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v6.0.0 hooks: - id: check-ast - id: check-merge-conflict - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - id: check-toml - id: check-added-large-files args: ["--maxkb=2000"] - repo: https://github.com/psf/black-pre-commit-mirror rev: 26.3.1 hooks: - id: black args: ["--check", "--diff"] - repo: https://github.com/pycqa/isort rev: 8.0.1 hooks: - id: isort args: ["--check", "--diff"] - repo: https://github.com/pre-commit/pygrep-hooks rev: v1.10.0 hooks: - id: rst-backticks - id: rst-directive-colons - id: rst-inline-touching-normal - repo: https://github.com/codespell-project/codespell rev: v2.4.2 hooks: - id: codespell args: ["-S", "*.lock"] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.15.9 hooks: - id: ruff - repo: https://github.com/kynan/nbstripout rev: 0.9.1 hooks: - id: nbstripout ================================================ FILE: .readthedocs.yaml ================================================ version: "2" build: os: ubuntu-24.04 tools: python: "mambaforge-23.11" conda: environment: environment.yml python: install: - method: pip path: . 
sphinx: configuration: docs/conf.py formats: all ================================================ FILE: CITATION.cff ================================================ cff-version: 1.2.0 title: JaxSim message: "If you use this software, please cite the paper." type: software authors: - family-names: Ferretti given-names: Filippo Luca affiliation: "Generative Bionics" - family-names: Ferigo given-names: Diego affiliation: "Robotics & AI Institute" - family-names: Croci given-names: Alessandro affiliation: "NEURA Robotics" - family-names: Sartore given-names: Carlotta affiliation: "Generative Bionics" - family-names: Younis given-names: Omar G. affiliation: "Quebec AI Institute" - family-names: Traversaro given-names: Silvio affiliation: "Generative Bionics" - family-names: Pucci given-names: Daniele affiliation: "Generative Bionics" repository-code: "https://github.com/gbionics/jaxsim" license: BSD-3-Clause preferred-citation: type: article title: "Contact-Aware Morphology Optimization via Physically Consistent Differentiable Simulation" authors: - family-names: Ferretti given-names: Filippo Luca - family-names: Ferigo given-names: Diego - family-names: Croci given-names: Alessandro - family-names: Sartore given-names: Carlotta - family-names: Younis given-names: Omar G. - family-names: Traversaro given-names: Silvio - family-names: Pucci given-names: Daniele journal: "IEEE Robotics and Automation Letters" year: 2026 doi: "10.1109/LRA.2026.3678125" ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to JAXsim :rocket: Hello Contributor, We're thrilled that you're considering contributing to JAXsim! Here's a brief guide to help you seamlessly become a part of our project. ## Development Environment :hammer_and_wrench: Make sure your development environment is set up. Follow the installation instructions in the [README](./README.md) to get JAXsim and its dependencies up and running. 
To ensure consistency and maintain code quality, we recommend using the pre-commit hooks configured in [`.pre-commit-config.yaml`](./.pre-commit-config.yaml). This will help catch issues before they become a part of the project. ### Setting Up Pre-commit Hook :fishing_pole_and_fish: `pre-commit` is a tool that manages pre-commit hooks for your project. It will run checks on your code before you commit it, ensuring that it meets the project's standards. First, install `pre-commit` if you haven't already: ```bash pip install pre-commit ``` Then, run the following command to install the hooks: ```bash pre-commit install ``` ### Using Pre-commit Hook :vertical_traffic_light: Before making any commits, the pre-commit hook will automatically run. If it finds any issues, it will prevent the commit and provide instructions on how to fix them. To get your commit through without fixing the issues, use the `--no-verify` flag: ```bash git commit -m "Your commit message" --no-verify ``` To manually run the pre-commit hook at any time, use: ```bash pre-commit run --all-files ``` ## Making Changes :construction: Before submitting a pull request, create an issue to discuss your changes if major changes are involved. This helps us understand your needs and provide feedback. Clearly describe your pull request, referencing any related issues. Follow the [PEP 8](https://peps.python.org/pep-0008/) style guide and include relevant tests. ## Testing :test_tube: Your code will be tested by the CI/CD pipeline, defined in the [workflows](./.github/workflows) folder, before merging. Feel free to add new tests or update the existing ones in the [tests](./tests) folder to cover your changes. ## Documentation :book: Update the documentation in the [docs](./docs) folder and the [README](./README.md) to reflect your changes, if necessary. There is no need to build the documentation locally; it will be automatically built and deployed with your pull request, where a preview link will be provided. ## Code Review :eyes: Expect feedback during the code review process.
Address comments and make necessary changes. This collaboration ensures quality. Please keep the commit history clean, or squash commits if necessary. ## License :scroll: JAXsim is under the [BSD 3-Clause License](./LICENSE). By contributing, you agree to the same license. Thank you for contributing to JAXsim! Your efforts are appreciated. ================================================ FILE: LICENSE ================================================ BSD 3-Clause License Copyright (c) 2022, Artificial and Mechanical Intelligence All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
================================================ FILE: README.md ================================================ # JaxSim **JaxSim** is a **differentiable physics engine** built with JAX, tailored for co-design and robotic learning applications.


## Features - Physically consistent differentiability w.r.t. hardware parameters. - Closed chain dynamics support. - Reduced-coordinate physics engine for **fixed-base** and **floating-base** robots. - Fully Python-based, leveraging [JAX][jax] following a functional programming paradigm. - Seamless execution on CPUs, GPUs, and TPUs. - Supports JIT compilation and automatic vectorization for high performance. - Compatible with SDF models and URDF (via [sdformat][sdformat] conversion). > [!WARNING] > This project is still experimental. APIs may change between releases without notice. > [!NOTE] > JaxSim currently focuses on locomotion applications. > Only contacts between bodies and smooth ground surfaces are supported. ## How to use it ```python import pathlib import icub_models import jax.numpy as jnp import jaxsim.api as js # Load the iCub model model_path = icub_models.get_model_file("iCubGazeboV2_5") joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch', 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch', 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch', 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll', 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll') # Build and reduce the model model_description = pathlib.Path(model_path) full_model = js.model.JaxSimModel.build_from_model_description( model_description=model_description, ) model = js.model.reduce(model=full_model, considered_joints=joints) # Get the number of degrees of freedom ndof = model.dofs() # Initialize data and simulation # Note that the default data representation is mixed velocity representation data = js.data.JaxSimModelData.build( model=model, base_position=jnp.array([0.0, 0.0, 1.0]) ) T = jnp.arange(start=0, stop=1.0, step=model.time_step) tau = jnp.zeros(ndof) # Simulate for _ in T: data = js.model.step( model=model, data=data, link_forces=None, joint_force_references=tau ) ``` Check the example 
folder for additional use cases! [jax]: https://github.com/google/jax/ [sdformat]: https://github.com/gazebosim/sdformat [notation]: https://research.tue.nl/en/publications/multibody-dynamics-notation-version-2 [passive_viewer_mujoco]: https://mujoco.readthedocs.io/en/stable/python.html#passive-viewer ## Installation
With conda You can install the project using [`conda`][conda] as follows: ```bash conda install jaxsim -c conda-forge ``` GPU support for JAX will be automatically installed if a compatible GPU is detected.
With pixi > ### Note > The minimum version of `pixi` required is `0.39.0`. Since the `pixi.lock` file is stored using Git LFS, make sure you have [Git LFS](https://github.com/git-lfs/git-lfs/blob/main/INSTALLING.md) installed and properly configured on your system before installation. After cloning the repository, run: ```bash git lfs install && git lfs pull ``` This ensures all LFS-tracked files are properly downloaded before you proceed with the installation. You can add the `jaxsim` dependency in your [`pixi`][pixi] project as follows: ```bash pixi add jaxsim ``` If you are on Linux and you want to use a `cuda`-powered version of `jax`, remember to add the appropriate line in the [`system-requirements`](https://pixi.sh/latest/reference/pixi_manifest/#the-system-requirements-table) table, i.e. adding ~~~toml [system-requirements] cuda = "13" ~~~ if you are using a `pixi.toml` file or ~~~toml [tool.pixi.system-requirements] cuda = "13" ~~~ if you are using a `pyproject.toml` file.
With pip You can install the project using [`pypa/pip`][pip], preferably in a [virtual environment][venv], as follows: ```bash pip install jaxsim ``` Check [`pyproject.toml`](pyproject.toml) for the complete list of optional dependencies. You can obtain a full installation using `jaxsim[all]`. If you need URDF support, follow the [official instructions](https://gazebosim.org/docs) to install Gazebo Sim on your operating system, making sure to obtain `sdformat ≥ 13.0` and `gz-tools ≥ 2.0`. You don't need to install the entire Gazebo Sim suite. For example, on Ubuntu, it is sufficient to install the `libsdformat*` and `gz-tools2` packages. If you need GPU support, follow the official [installation instructions][jax_gpu] of JAX.
Contributors installation (with conda) If you want to contribute to the project, we recommend creating the following `jaxsim` conda environment first: ```bash conda env create -f environment.yml ``` Then, activate the environment and install the project in editable mode: ```bash conda activate jaxsim pip install --no-deps -e . ```
Contributors installation (with pixi) > ### Note > The minimum version of `pixi` required is `0.39.0`. Since the `pixi.lock` file is stored using Git LFS, make sure you have [Git LFS](https://github.com/git-lfs/git-lfs/blob/main/INSTALLING.md) installed and properly configured on your system before installation. After cloning the repository, run: ```bash git lfs install && git lfs pull ``` This ensures all LFS-tracked files are properly downloaded before you proceed with the installation. You can install the default dependencies of the project using [`pixi`][pixi] as follows: ```bash pixi install ``` See `pixi task list` for a list of available tasks.
[conda]: https://anaconda.org/ [pip]: https://github.com/pypa/pip/ [pixi]: https://pixi.sh/ [venv]: https://docs.python.org/3/tutorial/venv.html [jax_gpu]: https://github.com/google/jax/#installation ## Documentation The JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs]. [readthedocs]: https://jaxsim.readthedocs.io/ ## Additional features JaxSim can also be used as a multibody dynamics library, with full support for automatic differentiation of RBDAs (forward and reverse mode) with respect to both kinematic and dynamic parameters. ### Using JaxSim as a multibody dynamics library ```python import pathlib import icub_models import jax.numpy as jnp import jaxsim.api as js # Load the iCub model model_path = icub_models.get_model_file("iCubGazeboV2_5") joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch', 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch', 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch', 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll', 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll') # Build and reduce the model model_description = pathlib.Path(model_path) full_model = js.model.JaxSimModel.build_from_model_description( model_description=model_description, ) model = js.model.reduce(model=full_model, considered_joints=joints) # Initialize model data data = js.data.JaxSimModelData.build( model=model, base_position=jnp.array([0.0, 0.0, 1.0]), ) # Frame and dynamics computations frame_index = js.frame.name_to_idx(model=model, frame_name="l_foot") # Frame transformation W_H_F = js.frame.transform( model=model, data=data, frame_index=frame_index ) # Frame Jacobian W_J_F = js.frame.jacobian( model=model, data=data, frame_index=frame_index ) # Dynamics properties M = js.model.free_floating_mass_matrix(model=model, data=data) # Mass matrix h = js.model.free_floating_bias_forces(model=model, data=data) # Bias
forces g = js.model.free_floating_gravity_forces(model=model, data=data) # Gravity forces C = js.model.free_floating_coriolis_matrix(model=model, data=data) # Coriolis matrix # Print dynamics results print(f"{M.shape=} \n{h.shape=} \n{g.shape=} \n{C.shape=}") ``` ## Credits The RBDAs are based on the theory of the [Rigid Body Dynamics Algorithms][RBDA] book by Roy Featherstone. The algorithms and some simulation features were inspired by its accompanying [code][spatial_v2]. [RBDA]: https://link.springer.com/book/10.1007/978-1-4899-7560-7 [spatial_v2]: http://royfeatherstone.org/spatial/index.html#spatial-software The development of JaxSim started in late 2021, inspired by early versions of [`google/brax`][brax]. At that time, Brax was implemented in maximal coordinates, and we wanted a physics engine in reduced coordinates. We are grateful to the Brax team for their work and for showing the potential of [JAX][jax] in this field. Brax v2 was later implemented with reduced coordinates, following an approach comparable to JaxSim. The development then shifted to [MJX][mjx], which provides a JAX-based implementation of the MuJoCo APIs. The main differences between MJX/Brax and JaxSim are as follows: - JaxSim supports all SDF models with [Pose Frame Semantics][PFS] out of the box. - JaxSim only supports collisions between points rigidly attached to bodies and a compliant ground surface. [brax]: https://github.com/google/brax [mjx]: https://mujoco.readthedocs.io/en/3.0.0/mjx.html [PFS]: http://sdformat.org/tutorials?tut=pose_frame_semantics ## Contributing We welcome contributions from the community. Please read the [contributing guide](./CONTRIBUTING.md) to get started. ## Citing If you use JaxSim in your work, please cite the following paper: ```bibtex @article{ferretti_contact_aware_2026, author = {Filippo Luca Ferretti and Diego Ferigo and Alessandro Croci and Carlotta Sartore and Omar G.
Younis and Silvio Traversaro and Daniele Pucci}, title = {Contact-Aware Morphology Optimization via Physically Consistent Differentiable Simulation}, journal = {IEEE Robotics and Automation Letters}, year = {2026}, doi = {10.1109/LRA.2026.3678125} } ``` ## People | Authors | Maintainer | |:------:|:-----------:| | [][df] [][ff] | [][ff] | [df]: https://github.com/diegoferigo [ff]: https://github.com/flferretti ## License [BSD3](https://choosealicense.com/licenses/bsd-3-clause/) ================================================ FILE: docs/Makefile ================================================ SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build SPHINXPROJ = JAXsim # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/conf.py ================================================ # Configuration file for the Sphinx documentation builder. 
import os import sys if os.environ.get("READTHEDOCS"): checkout_name = os.path.basename(os.path.dirname(os.path.realpath(__file__))) os.environ["CONDA_PREFIX"] = os.path.realpath( os.path.join("..", "..", "conda", checkout_name) ) import jaxsim # -- Version information sys.path.insert(0, os.path.abspath(".")) sys.path.insert(0, os.path.abspath("../")) sys.path.insert(0, os.path.abspath("../../")) module_path = os.path.abspath("../src/") sys.path.insert(0, module_path) __version__ = jaxsim._version.__version__ # -- Project information project = "JAXsim" copyright = "2022, Artificial and Mechanical Intelligence" author = "Artificial and Mechanical Intelligence" release = version = __version__ # -- General configuration extensions = [ "sphinx.ext.duration", "sphinx.ext.doctest", "sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.intersphinx", "sphinx.ext.mathjax", "sphinx.ext.ifconfig", "sphinx.ext.viewcode", "sphinx_rtd_theme", "sphinx.ext.napoleon", "sphinx_autodoc_typehints", "sphinx_multiversion", "myst_nb", "sphinx_gallery.gen_gallery", "sphinxcontrib.collections", "sphinx_design", ] # -- Options for intersphinx extension language = "en" html_theme = "sphinx_book_theme" templates_path = ["_templates"] html_title = f"JAXsim {version}" master_doc = "index" autodoc_typehints_format = "short" autodoc_typehints = "description" autosummary_generate = True epub_show_urls = "footnote" # Enable postponed evaluation of annotations (PEP 563) autodoc_type_aliases = { "jaxsim.typing.PyTree": "jaxsim.typing.PyTree", "jaxsim.typing.Vector": "jaxsim.typing.Vector", "jaxsim.typing.Matrix": "jaxsim.typing.Matrix", "jaxsim.typing.Array": "jaxsim.typing.Array", "jaxsim.typing.Int": "jaxsim.typing.Int", "jaxsim.typing.Bool": "jaxsim.typing.Bool", "jaxsim.typing.Float": "jaxsim.typing.Float", "jaxsim.typing.ScalarLike": "jaxsim.typing.ScalarLike", "jaxsim.typing.ArrayLike": "jaxsim.typing.ArrayLike", "jaxsim.typing.VectorLike": "jaxsim.typing.VectorLike", 
"jaxsim.typing.MatrixLike": "jaxsim.typing.MatrixLike", "jaxsim.typing.IntLike": "jaxsim.typing.IntLike", "jaxsim.typing.BoolLike": "jaxsim.typing.BoolLike", "jaxsim.typing.FloatLike": "jaxsim.typing.FloatLike", } # -- Options for sphinx-collections collections = { "examples": {"driver": "copy_folder", "source": "../examples/", "ignore": "assets"} } # -- Options for sphinx-gallery ---------------------------------------------- sphinx_gallery_conf = { "examples_dirs": "../examples", "gallery_dirs": "../generated_examples/", "doc_module": "jaxsim", } # -- Options for myst ------------------------------------------------------- myst_enable_extensions = [ "amsmath", "dollarmath", ] nb_execution_mode = "auto" nb_execution_raise_on_error = True nb_render_image_options = { "scale": "60", } nb_execution_timeout = 180 source_suffix = [".rst", ".md", ".ipynb"] # Ignore header warnings suppress_warnings = ["myst.header"] ================================================ FILE: docs/examples.rst ================================================ .. _collections: Example Notebooks ================= .. toctree:: :glob: :hidden: :maxdepth: 1 _collections/examples/README.md .. raw:: html
.. only:: html :doc:`_collections/examples/jaxsim_as_physics_engine` .. raw:: html
JaxSim as a hardware-accelerated parallel physics engine
.. only:: html :doc:`_collections/examples/jaxsim_as_physics_engine_advanced` .. raw:: html
JaxSim as a hardware-accelerated parallel physics engine [Advanced]
.. only:: html :doc:`_collections/examples/jaxsim_as_multibody_dynamics_library` .. raw:: html
JaxSim as a multibody dynamics library
.. only:: html :doc:`_collections/examples/jaxsim_for_robot_controllers` .. raw:: html
JaxSim for developing closed-loop robot controllers
================================================ FILE: docs/guide/configuration.rst ================================================ Configuration ============= JaxSim utilizes environment variables for application configuration. Below is a detailed overview of the various configuration categories and their respective variables. Collision Dynamics ~~~~~~~~~~~~~~~~~~ Environment variables starting with ``JAXSIM_COLLISION_`` are used to configure collision dynamics. The available variables are: - ``JAXSIM_COLLISION_SPHERE_POINTS``: Specifies the number of collision points to approximate the sphere. *Default:* ``50``. - ``JAXSIM_COLLISION_MESH_ENABLED``: Enables or disables mesh-based collision detection. *Default:* ``False``. - ``JAXSIM_COLLISION_USE_BOTTOM_ONLY``: Limits collision detection to only the bottom half of the box or sphere. *Default:* ``False``. .. note:: The bottom half is defined as the half of the box or sphere with the lowest z-coordinate in the collision link frame. Testing ~~~~~~~ For testing configurations, environment variables beginning with ``JAXSIM_TEST_`` are used. The following variables are available: - ``JAXSIM_TEST_SEED``: Defines the seed for the random number generator. *Default:* ``0``. - ``JAXSIM_TEST_AD_ORDER``: Specifies the gradient order for automatic differentiation tests. *Default:* ``1``. - ``JAXSIM_TEST_FD_STEP_SIZE``: Sets the step size for finite difference tests. *Default:* the cube root of the machine epsilon. Joint Dynamics ~~~~~~~~~~~~~~ Joint dynamics are configured using environment variables starting with ``JAXSIM_JOINT_``. Available variables include: - ``JAXSIM_JOINT_POSITION_LIMIT_DAMPER``: Overrides the damper value for joint position limits of the SDF model. - ``JAXSIM_JOINT_POSITION_LIMIT_SPRING``: Overrides the spring value for joint position limits of the SDF model. 
Logging and Exceptions ~~~~~~~~~~~~~~~~~~~~~~ The logging and exceptions configuration is controlled by the following environment variables: - ``JAXSIM_LOGGING_LEVEL``: Determines the logging level. *Default:* ``DEBUG`` for development, ``WARNING`` for production. - ``JAXSIM_ENABLE_EXCEPTIONS``: Enables the runtime checks and exceptions. Note that enabling exceptions might lead to device-to-host transfer of data, increasing the computational time required. *Default:* ``False``. .. note:: Runtime exceptions are disabled by default on TPU. ================================================ FILE: docs/guide/install.rst ================================================ Installation ============ .. _installation: Prerequisites ------------- JAXsim requires Python 3.11 or later. Basic Installation ------------------ You can install the project using `conda`_: .. code-block:: bash conda install jaxsim -c conda-forge Alternatively, you can use `pypa/pip`_, preferably in a `virtual environment`_: .. code-block:: bash pip install jaxsim Have a look at `pyproject.toml`_ for a complete list of optional dependencies. You can install all of them by using ``pip install "jaxsim[all]"``. .. note:: If you need GPU support, please follow the official `installation instruction`_ of JAX. .. _conda: https://anaconda.org/ .. _pyproject.toml: https://github.com/gbionics/jaxsim/blob/main/pyproject.toml .. _pypa/pip: https://github.com/pypa/pip/ .. _virtual environment: https://docs.python.org/3.8/tutorial/venv.html .. _installation instruction: https://github.com/google/jax/#installation ================================================ FILE: docs/index.rst ================================================ JAXsim ####### A scalable physics engine and multibody dynamics library implemented with JAX. With JIT batteries 🔋 .. note:: This simulator currently focuses on locomotion applications. Only contacts with the ground are supported. Features -------- .. grid:: .. grid-item:: :columns: 12 12 12 6 ..
card:: Performance :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal Physics engine in reduced coordinates implemented with JAX_. Compatibility with JIT compilation for increased performance and transparent support for executing logic on CPUs, GPUs, and TPUs. Parallel multi-body simulations on hardware accelerators for significantly increased throughput. .. grid-item:: :columns: 12 12 12 6 .. card:: Model Parsing :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal Support for SDF models (and, upon conversion, URDF models). Revolute, prismatic, and fixed joints are supported. .. grid-item:: :columns: 12 12 12 6 .. card:: Automatic Differentiation :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal Support for automatic differentiation of rigid body dynamics algorithms (RBDAs) for model-based robotics research. Soft contact model supporting the full friction cone and the sticking/slipping transition. .. grid-item:: :columns: 12 12 12 6 .. card:: Complex Dynamics :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal JAXsim provides a variety of integrators for the simulation of multibody dynamics, including RK4, Heun, Euler, and more. Support for `multiple velocities representations `_. ---- .. toctree:: :hidden: guide/install guide/configuration examples .. toctree:: :hidden: :maxdepth: 2 :caption: JAXsim API modules/api modules/math modules/mujoco modules/parsers modules/rbda modules/typing modules/utils Examples -------- Explore and learn how to use the library through practical demonstrations available in the `examples `__ folder. Credits ------- The physics module of JAXsim is based on the theory of the `Rigid Body Dynamics Algorithms `_ book by Roy Featherstone. We structured part of our logic following its accompanying `code `_. The physics engine is developed entirely in Python using JAX_.
The inspiration for developing JAXsim originally stemmed from early versions of Brax_. Below we summarize the differences between the projects: - JAXsim simulates multibody dynamics in reduced coordinates, while :code:`brax v1` uses maximal coordinates. - The new v2 APIs of brax (and the new MJX_) were then implemented in reduced coordinates, following an approach comparable to JAXsim, with major differences in contact handling. - The rigid-body algorithms used in JAXsim allow efficient computation of quantities based on the Euler-Poincaré formulation of the equations of motion, necessary for model-based robotics research. - JAXsim supports SDF (and, indirectly, URDF) models, assuming the model is described with the recent `Pose Frame Semantics `_. - Unlike brax, JAXsim only supports collision detection between bodies and a compliant ground surface. - The RBDAs of JAXsim support automatic differentiation, but this functionality has not been thoroughly tested. People ------ Authors ''''''' `Diego Ferigo `_ `Filippo Luca Ferretti `_ Maintainers ''''''''''' `Filippo Luca Ferretti `_ `Alessandro Croci `_ License ------- `BSD3 `_ .. _Brax: https://github.com/google/brax .. _MJX: https://mujoco.readthedocs.io/en/3.0.0/mjx.html .. _JAX: https://github.com/google/jax ================================================ FILE: docs/make.bat ================================================ @ECHO OFF pushd %~dp0 REM if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. set BUILDDIR=_build set SPHINXPROJ=JaxSim if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo.
echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd ================================================ FILE: docs/modules/api.rst ================================================ Functional API ============== .. currentmodule:: jaxsim.api .. autosummary:: :toctree: _autosummary model data contact kin_dyn_parameters integrators joint link frame com ode references actuation_model common Model ~~~~~ .. automodule:: jaxsim.api.model :members: :no-index: .. automodule:: jaxsim.api.actuation_model :members: :no-index: Data ~~~~ .. automodule:: jaxsim.api.data :members: :no-index: Contact ~~~~~~~ .. automodule:: jaxsim.api.contact :members: :no-index: KinDynParameters ~~~~~~~~~~~~~~~~ .. automodule:: jaxsim.api.kin_dyn_parameters :members: :no-index: Joint ~~~~~ .. automodule:: jaxsim.api.joint :members: :no-index: Link ~~~~~ .. automodule:: jaxsim.api.link :members: :no-index: Frame ~~~~~ .. automodule:: jaxsim.api.frame :members: :no-index: CoM ~~~ .. automodule:: jaxsim.api.com :members: :no-index: Integration ~~~~~~~~~~~ .. automodule:: jaxsim.api.integrators :members: :no-index: .. automodule:: jaxsim.api.ode :members: :no-index: References ~~~~~~~~~~ .. automodule:: jaxsim.api.references :members: :no-index: Common ~~~~~~ .. autoclass:: jaxsim.api.common.VelRepr :members: .. autoclass:: jaxsim.api.common.ModelDataWithVelocityRepresentation :members: ================================================ FILE: docs/modules/math.rst ================================================ Math ==== .. currentmodule:: jaxsim.math .. automodule:: jaxsim.math.adjoint :members: :undoc-members: .. automodule:: jaxsim.math.cross :members: :undoc-members: .. automodule:: jaxsim.math.inertia :members: :undoc-members: .. automodule:: jaxsim.math.quaternion :members: :undoc-members: .. 
automodule:: jaxsim.math.rotation :members: :undoc-members: .. automodule:: jaxsim.math.skew :members: :undoc-members: ================================================ FILE: docs/modules/mujoco.rst ================================================ MuJoCo Visualizer ================== JAXsim provides a simple interface with MuJoCo's visualizer. The visualizer is a separate process that communicates with the main simulation process. This allows for the simulation to run at full speed while the visualizer can run at a different frame rate. .. currentmodule:: jaxsim.mujoco Loaders ~~~~~~~ .. automodule:: jaxsim.mujoco.loaders :members: Model ~~~~~ .. automodule:: jaxsim.mujoco.model :members: Visualizer ~~~~~~~~~~ .. automodule:: jaxsim.mujoco.visualizer :members: ================================================ FILE: docs/modules/parsers.rst ================================================ Parsers ======= .. automodule:: jaxsim.parsers.descriptions.collision :members: .. automodule:: jaxsim.parsers.descriptions.joint :members: .. automodule:: jaxsim.parsers.descriptions.link :members: .. automodule:: jaxsim.parsers.descriptions.model :members: ================================================ FILE: docs/modules/rbda.rst ================================================ Rigid Body Dynamics Algorithms ============================== This module provides a set of algorithms for rigid body dynamics. .. currentmodule:: jaxsim.rbda .. autosummary:: :toctree: _autosummary aba collidable_points contacts.soft contacts.rigid contacts.relaxed_rigid crba forward_kinematics jacobian utils Collision Detection ~~~~~~~~~~~~~~~~~~~ .. automodule:: jaxsim.rbda.collidable_points :members: :no-index: Contact Models ~~~~~~~~~~~~~~ .. automodule:: jaxsim.rbda.contacts.soft :members: :no-index: .. automodule:: jaxsim.rbda.contacts.rigid :members: :no-index: .. automodule:: jaxsim.rbda.contacts.relaxed_rigid :members: :no-index: Utilities ~~~~~~~~~ .. 
automodule:: jaxsim.rbda.utils :members: :no-index: ================================================ FILE: docs/modules/typing.rst ================================================ Typing ====== .. currentmodule:: jaxsim.typing .. autosummary:: PyTree Matrix Bool Int Float Vector BoolLike FloatLike IntLike ArrayLike VectorLike MatrixLike ================================================ FILE: docs/modules/utils.rst ================================================ Utils ===== .. automodule:: jaxsim.utils :members: :inherited-members: .. autoclass:: jaxsim.utils.JaxsimDataclass :members: :inherited-members: ================================================ FILE: environment.yml ================================================ name: jaxsim channels: - conda-forge dependencies: # =========================== # Dependencies from setup.cfg # =========================== - python >= 3.12.0 - coloredlogs - jax >= 0.4.34 - jaxlib >= 0.4.34 - jaxlie >= 1.3.0 - jax-dataclasses >= 1.4.0 - optax >= 0.2.3 - pptree - qpax - rod >= 0.3.3 - trimesh - typing_extensions # python<3.12 # ==================================== # Optional dependencies from setup.cfg # ==================================== # [testing] - chex - idyntree >= 12.2.1 - pytest - pytest-benchmark - pytest-icdiff - robot_descriptions >= 1.16.0 - icub-models # [viz] - lxml - mediapy - mujoco >= 3.0.0 - scipy >= 1.14.0 # ========================== # Documentation dependencies # ========================== - cachecontrol - filecache - jinja2 - myst-nb - pip - sphinx - sphinx-autodoc-typehints - sphinx-book-theme - sphinx-copybutton - sphinx-design - sphinx_fontawesome - sphinx-gallery - sphinx-jinja2-compat - sphinx-multiversion - sphinx_rtd_theme - sphinx-toolbox - icub-models # ======================================== # Other dependencies for GitHub Codespaces # ======================================== - ipython - pip: - sphinx-collections # TODO (flferretti): PR to conda-forge 
================================================ FILE: examples/.gitattributes ================================================ # GitHub syntax highlighting pixi.lock linguist-language=YAML ================================================ FILE: examples/.gitignore ================================================ # pixi environments .pixi ================================================ FILE: examples/README.md ================================================ # JaxSim Examples This folder contains Jupyter notebooks that demonstrate the practical usage of JaxSim. ## Featured examples | Notebook | Google Colab | Description | | :--- | :---: | :--- | | [`jaxsim_as_multibody_dynamics_library`](./jaxsim_as_multibody_dynamics_library.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_multibody_dynamics] | An example demonstrating how to use JaxSim as a multibody dynamics library. | | [`jaxsim_as_physics_engine.ipynb`](./jaxsim_as_physics_engine.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_physics_engine] | An example demonstrating how to simulate vectorized models in parallel. | | [`jaxsim_as_physics_engine_advanced.ipynb`](./jaxsim_as_physics_engine_advanced.ipynb) | [![Open In Colab][colab_badge]][jaxsim_as_physics_engine_advanced] | An example showcasing advanced JaxSim usage, such as customizing the integrator, contact model, and more. | | [`jaxsim_for_robot_controllers.ipynb`](./jaxsim_for_robot_controllers.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_closed_loop] | A basic example showing how to simulate a PD controller with gravity compensation for a 2-DOF cart-pole. 
| [colab_badge]: https://colab.research.google.com/assets/colab-badge.svg [ipynb_jaxsim_closed_loop]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_for_robot_controllers.ipynb [ipynb_jaxsim_as_physics_engine]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_physics_engine.ipynb [jaxsim_as_physics_engine_advanced]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_physics_engine_advanced.ipynb [ipynb_jaxsim_as_multibody_dynamics]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_multibody_dynamics_library.ipynb ## How to run the examples You can run the JaxSim examples with hardware acceleration in two ways. ### Option 1: Google Colab (recommended) The easiest way is to use the provided Google Colab links to run the notebooks in a hosted environment with no setup required. ### Option 2: Local execution with `pixi` To run the examples locally, first install `pixi` following the [official documentation][pixi_installation]: [pixi_installation]: https://pixi.sh/#installation ```bash curl -fsSL https://pixi.sh/install.sh | bash ``` Then, from the repository's root directory, execute the example notebooks using: ```bash pixi run examples ``` This command will automatically handle all necessary dependencies and run the examples in a self-contained environment. ================================================ FILE: examples/assets/build_cartpole_urdf.py ================================================ import os if "ROD_LOGGING_LEVEL" not in os.environ: os.environ["ROD_LOGGING_LEVEL"] = "WARNING" import numpy as np import rod.kinematics.tree_transforms from rod.builder import primitives if __name__ == "__main__": # ================ # Model parameters # ================ # Rail parameters. rail_height = 1.2 rail_length = 5.0 rail_radius = 0.005 rail_mass = 5.0 # Cart parameters. 
cart_mass = 1.0 cart_size = (0.1, 0.2, 0.05) # Pole parameters. pole_mass = 0.5 pole_length = 1.0 pole_radius = 0.005 # ======================== # Create the link builders # ======================== rail_builder = primitives.CylinderBuilder( name="rail", mass=rail_mass, radius=rail_radius, length=rail_length, ) cart_builder = primitives.BoxBuilder( name="cart", mass=cart_mass, x=cart_size[0], y=cart_size[1], z=cart_size[2], ) pole_builder = primitives.CylinderBuilder( name="pole", mass=pole_mass, radius=pole_radius, length=pole_length, ) # ================= # Create the joints # ================= world_to_rail = rod.Joint( name="world_to_rail", type="fixed", parent="world", child=rail_builder.name, pose=primitives.PrimitiveBuilder.build_pose( relative_to="world", ), ) linear = rod.Joint( name="linear", type="prismatic", parent=rail_builder.name, child=cart_builder.name, pose=primitives.PrimitiveBuilder.build_pose( relative_to=rail_builder.name, pos=np.array([0, 0, rail_height]), ), axis=rod.Axis( xyz=rod.Xyz(xyz=[0, 1, 0]), limit=rod.Limit( upper=(rail_length / 2 - cart_size[1] / 2), lower=-(rail_length / 2 - cart_size[1] / 2), effort=500.0, velocity=10.0, ), ), ) pivot = rod.Joint( name="pivot", type="continuous", parent=cart_builder.name, child=pole_builder.name, pose=primitives.PrimitiveBuilder.build_pose( relative_to=cart_builder.name, ), axis=rod.Axis( xyz=rod.Xyz(xyz=[1, 0, 0]), limit=rod.Limit(), ), ) # ================ # Create the links # ================ rail_elements_pose = primitives.PrimitiveBuilder.build_pose( pos=np.array([0, 0, rail_height]), rpy=np.array([np.pi / 2, 0, 0]), ) rail = ( rail_builder.build_link( name=rail_builder.name, pose=primitives.PrimitiveBuilder.build_pose( relative_to=world_to_rail.name, ), ) .add_inertial(pose=rail_elements_pose) .add_visual(pose=rail_elements_pose) .add_collision(pose=rail_elements_pose) .build() ) cart = ( cart_builder.build_link( name=cart_builder.name, 
pose=primitives.PrimitiveBuilder.build_pose(relative_to=linear.name), ) .add_inertial() .add_visual() .add_collision() .build() ) pole_elements_pose = primitives.PrimitiveBuilder.build_pose( pos=np.array([0, 0, pole_length / 2]), ) pole = ( pole_builder.build_link( name=pole_builder.name, pose=primitives.PrimitiveBuilder.build_pose( relative_to=pivot.name, ), ) .add_inertial(pose=pole_elements_pose) .add_visual(pose=pole_elements_pose) .add_collision(pose=pole_elements_pose) .build() ) # =========== # Build model # =========== # Create ROD in-memory model. model = rod.Model( name="cartpole", canonical_link=rail.name, link=[ rail, cart, pole, ], joint=[ world_to_rail, linear, pivot, ], ) # Update the pose elements to be closer to those expected in URDF. model.switch_frame_convention( frame_convention=rod.FrameConvention.Urdf, explicit_frames=True ) # ============== # Get SDF string # ============== # Create the top-level SDF object. sdf = rod.Sdf(version="1.10", model=model) # Generate the SDF string. # sdf_string = sdf.serialize(pretty=True, validate=True) # =============== # Get URDF string # =============== import rod.urdf.exporter # Convert the SDF to URDF. urdf_string = rod.urdf.exporter.UrdfExporter( pretty=True, indent=" " ).to_urdf_string(sdf=sdf) # Print the URDF string. print(urdf_string) ================================================ FILE: examples/assets/cartpole.urdf ================================================ ================================================ FILE: examples/jaxsim_as_multibody_dynamics_library.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "DpLq0-lltwZ1" }, "source": [ "# `JaxSim` as a multibody dynamics library\n", "\n", "JaxSim was initially developed as a **hardware-accelerated physics engine**. 
Over time, it has evolved, adding new features to become a comprehensive **JAX-based multibody dynamics library**.\n", "\n", "In this notebook, you'll explore the main APIs for loading robot models and computing key quantities for applications such as control, planning, and more.\n", "\n", "A key advantage of JaxSim is its ability to create fully differentiable closed-loop systems, enabling end-to-end optimization. Combined with the flexibility to parameterize model kinematics and dynamics, JaxSim can serve as an excellent playground for robot learning applications.\n", "\n", "\n", " \"Open\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "rcEwprINtwZ3" }, "source": [ "## Prepare environment\n", "\n", "First, we need to install the necessary packages and import their resources." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "u4xL7dbBtwZ3", "outputId": "1a088e28-e005-4910-928c-cb641e589ab5" }, "outputs": [], "source": [ "# @title Imports and setup\n", "from IPython.display import clear_output\n", "import sys\n", "\n", "IS_COLAB = \"google.colab\" in sys.modules\n", "\n", "# Install JAX, sdformat, and other notebook dependencies.\n", "if IS_COLAB:\n", " !{sys.executable} -m pip install --pre -qU jaxsim\n", " !{sys.executable} -m pip install robot_descriptions>=1.16.0\n", " !apt install -qq lsb-release wget gnupg\n", " !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n", " !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n", " !apt -qq update\n", " !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n", "\n", " clear_output()\n", "\n", "import os\n", "import pathlib\n", "\n", "import jax\n", "import jax.numpy 
as jnp\n", "import jaxsim.api as js\n", "import jaxsim.math\n", "from jaxsim import logging\n", "from jaxsim import VelRepr\n", "\n", "logging.set_logging_level(logging.LoggingLevel.WARNING)\n", "print(f\"Running on {jax.devices()}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "fN8Xg4QgtwZ4" }, "source": [ "## Robot model\n", "\n", "JaxSim allows loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files.\n", "\n", "In this example, we will use the [ErgoCub][ergocub] humanoid robot model. If you have a URDF/SDF file for your robot that is compatible with [`gazebosim/sdformat`][sdformat_github][1], it should work out-of-the-box with JaxSim.\n", "\n", "[sdformat]: http://sdformat.org/\n", "[urdf]: http://wiki.ros.org/urdf/\n", "[ergocub]: https://ergocub.eu/\n", "[sdformat_github]: https://github.com/gazebosim/sdformat\n", "\n", "---\n", "\n", "[1]: JaxSim validates robot descriptions using the command `gz sdf -p /path/to/file.urdf`. Ensure this command runs successfully on your file.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rB0BFxyPtwZ5" }, "outputs": [], "source": [ "# @title Fetch the URDF file\n", "\n", "try:\n", " os.environ[\"ROBOT_DESCRIPTION_COMMIT\"] = \"v0.7.7\"\n", "\n", " import robot_descriptions.ergocub_description\n", "\n", "finally:\n", " _ = os.environ.pop(\"ROBOT_DESCRIPTION_COMMIT\", None)\n", "\n", "model_description_path = pathlib.Path(\n", " robot_descriptions.ergocub_description.URDF_PATH.replace(\n", " \"ergoCubSN002\", \"ergoCubSN001\"\n", " )\n", ")\n", "\n", "clear_output()" ] }, { "cell_type": "markdown", "metadata": { "id": "jeTUZic8twZ5" }, "source": [ "### Create the model and its data\n", "\n", "The dynamics of a generic floating-base model are governed by the following equations of motion:\n", "\n", "$$\n", "M(\\mathbf{q}) \\dot{\\boldsymbol{\\nu}} + \\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) = B \\boldsymbol{\\tau} + \\sum_{L_i \\in \\mathcal{L}} 
J_{W,L_i}^\\top(\\mathbf{q}) \\: \\mathbf{f}_i\n", ".\n", "$$\n", "\n", "Here, the system state is represented by:\n", "\n", "- $\\mathbf{q} = ({}^W \\mathbf{p}_B, \\, \\mathbf{s}) \\in \\text{SE}(3) \\times \\mathbb{R}^n$ is the generalized position.\n", "- $\\boldsymbol{\\nu} = (\\boldsymbol{v}_{W,B}, \\, \\boldsymbol{\\omega}_{W,B}, \\, \\dot{\\mathbf{s}}) \\in \\mathbb{R}^{6+n}$ is the generalized velocity.\n", "\n", "The inputs to the system are:\n", "\n", "- $\\boldsymbol{\\tau} \\in \\mathbb{R}^n$ are the joint torques.\n", "- $\\mathbf{f}_i \\in \\mathbb{R}^6$ is the 6D force applied to the link $L_i$.\n", "\n", "JaxSim exposes functional APIs to operate over the following two main data structures:\n", "\n", "- **`JaxSimModel`** stores all the constant information parsed from the model description.\n", "- **`JaxSimModelData`** holds the state of model.\n", "\n", "Additionally, JaxSim includes a utility class, **`JaxSimModelReferences`**, for managing and manipulating system inputs.\n", "\n", "---\n", "\n", "This notebook uses the notation summarized in the following report. Please refer to this document if you have any questions or if something is unclear.\n", "\n", "> Traversaro and Saccon, **Multibody dynamics notation**, 2019, [URL](https://pure.tue.nl/ws/portalfiles/portal/139293126/A_Multibody_Dynamics_Notation_Revision_2_.pdf)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WYgBAxU0twZ6" }, "outputs": [], "source": [ "# Create the model from the model description.\n", "# JaxSim removes all fixed joints by lumping together their parent and child links.\n", "full_model = js.model.JaxSimModel.build_from_model_description(\n", " model_description=model_description_path\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "DdaETmDStwZ6" }, "source": [ "It is often useful to work with only a subset of joints, referred to as the _considered joints_. 
JaxSim allows to reduce a model so that the computation of the rigid body dynamics quantities is simplified.\n", "\n", "By default, the positions of the removed joints are considered to be zero. If this is not the case, the `reduce` function accepts a dictionary `dict[str, float]` to specify custom joint positions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QuhG7Zv5twZ7" }, "outputs": [], "source": [ "model = js.model.reduce(\n", " model=full_model,\n", " considered_joints=tuple(\n", " j\n", " for j in full_model.joint_names()\n", " # Remove sensor joints.\n", " if \"camera\" not in j\n", " # Remove head and hands.\n", " and \"neck\" not in j\n", " and \"wrist\" not in j\n", " and \"thumb\" not in j\n", " and \"index\" not in j\n", " and \"middle\" not in j\n", " and \"ring\" not in j\n", " and \"pinkie\" not in j\n", " # Remove upper body.\n", " and \"torso\" not in j and \"elbow\" not in j and \"shoulder\" not in j\n", " ),\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "RLvAit_i2ZiA", "outputId": "ea3954af-b9b9-46ac-d9cb-20b99b1eac94" }, "outputs": [], "source": [ "# Print model quantities.\n", "print(f\"Model name: {model.name()}\")\n", "print(f\"Number of links: {model.number_of_links()}\")\n", "print(f\"Number of joints: {model.number_of_joints()}\")\n", "\n", "print()\n", "print(f\"Links:\\n{model.link_names()}\")\n", "\n", "print()\n", "print(f\"Joints:\\n{model.joint_names()}\")\n", "\n", "print()\n", "print(f\"Frames:\\n{model.frame_names()}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Xp8V5on5twZ8", "outputId": "cc1564db-ae91-4dba-92c9-b8b87bd65f10" }, "outputs": [], "source": [ "# Create a random data object from the reduced model.\n", "data = js.data.random_model_data(model=model)\n", "\n", "# Print the default state.\n", "W_H_B, s = 
data.generalized_position\n", "ν = data.generalized_velocity\n", "\n", "print(f\"W_H_B: shape={W_H_B.shape}\\n{W_H_B}\\n\")\n", "print(f\"s: shape={s.shape}\\n{s}\\n\")\n", "print(f\"ν: shape={ν.shape}\\n{ν}\\n\") # noqa: RUF001" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "XLx3sv9VtwZ9", "outputId": "28f5f070-e37e-464e-d84e-2944cfdc28dc" }, "outputs": [], "source": [ "# Create a random link forces matrix.\n", "link_forces = jax.random.uniform(\n", " minval=-10.0,\n", " maxval=10.0,\n", " shape=(model.number_of_links(), 6),\n", " key=jax.random.PRNGKey(0),\n", ")\n", "\n", "# Create a random joint force references vector.\n", "# Note that these are called 'references' because the actual joint forces that\n", "# are actuated might differ due to effects like joint friction.\n", "joint_force_references = jax.random.uniform(\n", " minval=-10.0, maxval=10.0, shape=(model.dofs(),), key=jax.random.PRNGKey(0)\n", ")\n", "\n", "# Create the references object.\n", "references = js.references.JaxSimModelReferences.build(\n", " model=model,\n", " data=data,\n", " link_forces=link_forces,\n", " joint_force_references=joint_force_references,\n", ")\n", "\n", "print(f\"link_forces: shape={references.link_forces(model=model, data=data).shape}\")\n", "print(f\"joint_force_references: shape={references.joint_force_references(model=model).shape}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "AaG817vP4LfT" }, "source": [ "## Robot Kinematics\n", "\n", "JaxSim offers functional APIs for computing kinematic quantities:\n", "\n", "- **`jaxsim.api.model`**: vectorized functions operating on the whole model.\n", "- **`jaxsim.api.link`**: functions operating on individual links.\n", "- **`jaxsim.api.frame`**: functions operating on individual frames. \n", "\n", "Due to JAX limitations on vectorizable data types, many APIs operate on indices instead of names. 
Since using indices can be error prone, JaxSim provides conversion functions for both links:\n", "\n", "- **jaxsim.api.link.names_to_idxs()**\n", "- **jaxsim.api.link.idxs_to_names()**\n", "\n", "and frames: \n", "\n", "- **jaxsim.api.frame.names_to_idxs()**\n", "- **jaxsim.api.frame.idxs_to_names()**\n", "\n", "We recommend using names whenever possible to avoid hard-to-trace errors.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QxImwfZz7pz-" }, "outputs": [], "source": [ "# Find the index of a link.\n", "link_name = \"l_ankle_2\"\n", "link_index = js.link.name_to_idx(model=model, link_name=link_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "C22Iqu2i4G-I", "outputId": "94376151-177d-410f-f375-b7b8bd080992" }, "outputs": [], "source": [ "# @title Link Pose\n", "\n", "# Compute its pose w.r.t. the world frame through forward kinematics.\n", "W_H_L = js.link.transform(model=model, data=data, link_index=link_index)\n", "\n", "print(f\"Transform of '{link_name}': shape={W_H_L.shape}\\n{W_H_L}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DnSpE_f97RkX", "outputId": "a3f6b535-4ae5-49f4-8921-7fe4dda5debb" }, "outputs": [], "source": [ "# @title Link 6D Velocity\n", "\n", "# JaxSim allows to select the so-called representation of the frame velocity.\n", "L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body)\n", "LW_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Mixed)\n", "W_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial)\n", "\n", "print(f\"Body-fixed velocity L_v_WL={L_v_WL}\")\n", "print(f\"Mixed velocity: LW_v_WL={LW_v_WL}\")\n", "print(f\"Inertial-fixed velocity: W_v_WL={W_v_WL}\")\n", "\n", "# These can also be 
computed passing through the link free-floating Jacobian.\n", "# This type of Jacobian has a input velocity representation that corresponds\n", "# the velocity representation of ν, and an output velocity representation that\n", "# corresponds to the velocity representation of the desired 6D velocity.\n", "\n", "# You can use the following context manager to easily switch between representations.\n", "with data.switch_velocity_representation(VelRepr.Body):\n", "\n", " # Body-fixed generalized velocity.\n", " B_ν = data.generalized_velocity\n", "\n", " # Free-floating Jacobian accepting a body-fixed generalized velocity and\n", " # returning an inertial-fixed link velocity.\n", " W_J_WL_B = js.link.jacobian(\n", " model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial\n", " )\n", "\n", "# Now the following relation should hold.\n", "assert jnp.allclose(W_v_WL, W_J_WL_B @ B_ν)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SSoziCShtwZ9" }, "outputs": [], "source": [ "# Find the index of a frame.\n", "frame_name = \"l_foot_front\"\n", "frame_index = js.frame.name_to_idx(model=model, frame_name=frame_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fVp_xP_1twZ9", "outputId": "cfaa0569-d768-4708-c98c-a5867c056d04" }, "outputs": [], "source": [ "# @title Frame Pose\n", "\n", "# Compute its pose w.r.t. 
the world frame through forward kinematics.\n", "W_H_F = js.frame.transform(model=model, data=data, frame_index=frame_index)\n", "\n", "print(f\"Transform of '{frame_name}': shape={W_H_F.shape}\\n{W_H_F}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QqaqxneEFYiW" }, "outputs": [], "source": [ "# @title Frame 6D Velocity\n", "\n", "# JaxSim allows to select the so-called representation of the frame velocity.\n", "F_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Body)\n", "FW_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Mixed)\n", "W_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial)\n", "\n", "print(f\"Body-fixed velocity F_v_WF={F_v_WF}\")\n", "print(f\"Mixed velocity: FW_v_WF={FW_v_WF}\")\n", "print(f\"Inertial-fixed velocity: W_v_WF={W_v_WF}\")\n", "\n", "# These can also be computed passing through the frame free-floating Jacobian.\n", "# This type of Jacobian has a input velocity representation that corresponds\n", "# the velocity representation of ν, and an output velocity representation that\n", "# corresponds to the velocity representation of the desired 6D velocity.\n", "\n", "# You can use the following context manager to easily switch between representations.\n", "with data.switch_velocity_representation(VelRepr.Body):\n", "\n", " # Body-fixed generalized velocity.\n", " B_ν = data.generalized_velocity\n", "\n", " # Free-floating Jacobian accepting a body-fixed generalized velocity and\n", " # returning an inertial-fixed link velocity.\n", " W_J_WF_B = js.frame.jacobian(\n", " model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial\n", " )\n", "\n", "# Now the following relation should hold.\n", "assert jnp.allclose(W_v_WF, W_J_WF_B @ B_ν)" ] }, { "cell_type": "markdown", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, 
"id": "d_vp6D74GoVZ", "outputId": "798b9283-792e-4339-b56c-df2595fac974" }, "source": [ "## Robot Dynamics\n", "\n", "JaxSim provides all the quantities involved in the equations of motion, restated here:\n", "\n", "$$\n", "M(\\mathbf{q}) \\dot{\\boldsymbol{\\nu}} + \\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) = B \\boldsymbol{\\tau} + \\sum_{L_i \\in \\mathcal{L}} J_{W,L_i}^\\top(\\mathbf{q}) \\: \\mathbf{f}_i\n", ".\n", "$$\n", "\n", "Specifically, it can compute:\n", "\n", "- $M(\\mathbf{q}) \\in \\mathbb{R}^{(6+n)\\times(6+n)}$: the mass matrix.\n", "- $\\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) \\in \\mathbb{R}^{6+n}$: the vector of bias forces.\n", "- $B \\in \\mathbb{R}^{(6+n) \\times n}$ the joint selector matrix.\n", "- $J_{W,L} \\in \\mathbb{R}^{6 \\times (6+n)}$ the Jacobian of link $L$.\n", "\n", "Often, for convenience, link Jacobians are stacked together. Since JaxSim efficiently computes the Jacobians for all links, using the stacked version is recommended when needed:\n", "\n", "$$\n", "M(\\mathbf{q}) \\dot{\\boldsymbol{\\nu}} + \\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) = B \\boldsymbol{\\tau} + J_{W,\\mathcal{L}}^\\top(\\mathbf{q}) \\: \\mathbf{f}_\\mathcal{L}\n", ".\n", "$$\n", "\n", "Furthermore, there are applications that require unpacking the vector of bias forces as follow:\n", "\n", "$$\n", "\\mathbf{h}(\\mathbf{q}, \\boldsymbol{\\nu}) = C(\\mathbf{q}, \\boldsymbol{\\nu}) \\boldsymbol{\\nu} + \\mathbf{g}(\\mathbf{q})\n", ",\n", "$$\n", "\n", "where:\n", "\n", "- $\\mathbf{g}(\\mathbf{q}) \\in \\mathbb{R}^{6+n}$: the vector of gravity forces.\n", "- $C(\\mathbf{q}, \\boldsymbol{\\nu}) \\in \\mathbb{R}^{(6+n)\\times(6+n)}$: the Coriolis matrix.\n", "\n", "Here below we report the functions to compute all these quantities. Note that all quantities depend on the active velocity representation of `data`. 
As it was done for the link velocity, it is possible to change the representation associated to all the computed quantities by operating within the corresponding context manager. Here below we consider the default representation of data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oOKJOVfsH4Ki" }, "outputs": [], "source": [ "print(\"Velocity representation of data:\", data.velocity_representation, \"\\n\")\n", "\n", "# Compute the mass matrix.\n", "M = js.model.free_floating_mass_matrix(model=model, data=data)\n", "print(f\"M: shape={M.shape}\")\n", "\n", "# Compute the vector of bias forces.\n", "h = js.model.free_floating_bias_forces(model=model, data=data)\n", "print(f\"h: shape={h.shape}\")\n", "\n", "# Compute the vector of gravity forces.\n", "g = js.model.free_floating_gravity_forces(model=model, data=data)\n", "print(f\"g: shape={g.shape}\")\n", "\n", "# Compute the Coriolis matrix.\n", "C = js.model.free_floating_coriolis_matrix(model=model, data=data)\n", "print(f\"C: shape={C.shape}\")\n", "\n", "# Create a the joint selector matrix.\n", "B = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T\n", "print(f\"B: shape={B.shape}\")\n", "\n", "# Compute the stacked tensor of link Jacobians.\n", "J = js.model.generalized_free_floating_jacobian(model=model, data=data)\n", "print(f\"J: shape={J.shape}\")\n", "\n", "# Extract the joint forces from the references object.\n", "τ = references.joint_force_references(model=model)\n", "print(f\"τ: shape={τ.shape}\")\n", "\n", "# Extract the link forces from the references object.\n", "f_L = references.link_forces(model=model, data=data)\n", "print(f\"f_L: shape={f_L.shape}\")\n", "\n", "# The following relation should hold.\n", "assert jnp.allclose(h, C @ ν + g)" ] }, { "cell_type": "markdown", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FlNo8dNWKKtu", "outputId": "313e939b-f88f-4407-c9ee-b5b3b7443061" }, "source": [ "### Forward 
Dynamics\n", "\n", "$$\n", "\\dot{\\boldsymbol{\\nu}} = \\text{FD}(\\mathbf{q}, \\boldsymbol{\\nu}, \\boldsymbol{\\tau}, \\mathbf{f}_{\\mathcal{L}})\n", "$$\n", "\n", "JaxSim provides two alternative methods to compute the forward dynamics:\n", "\n", "1. Operate on the quantities of the equations of motion.\n", "2. Call the recursive Articulated Body Algorithm (ABA).\n", "\n", "The physics engine provided by JaxSim exploits the efficient calculation of the forward dynamics with ABA for simulating the trajectories of the system dynamics." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LXARuRu1Ly1K" }, "outputs": [], "source": [ "ν̇_eom = jnp.linalg.pinv(M) @ (B @ τ - h + jnp.einsum(\"l6g,l6->g\", J, f_L))\n", "\n", "v̇_WB, s̈ = js.model.forward_dynamics_aba(\n", " model=model, data=data, link_forces=f_L, joint_forces=joint_force_references\n", ")\n", "\n", "ν̇_aba = jnp.hstack([v̇_WB, s̈])\n", "print(f\"ν̇: shape={ν̇_aba.shape}\") # noqa: RUF001\n", "\n", "# The following relation should hold.\n", "assert jnp.allclose(ν̇_eom, ν̇_aba)" ] }, { "cell_type": "markdown", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "g5GOYXDnLySU", "outputId": "ad4ce77d-d06f-473a-9c32-040680d76aa5" }, "source": [ "### Inverse Dynamics\n", "\n", "$$\n", "(\\boldsymbol{\\tau}, \\, \\mathbf{f}_B) = \\text{ID}(\\mathbf{q}, \\boldsymbol{\\nu}, \\dot{\\boldsymbol{\\nu}}, \\mathbf{f}_{\\mathcal{L}})\n", "$$\n", "\n", "JaxSim offers two methods to compute inverse dynamics:\n", "\n", "- Directly use the quantities from the equations of motion.\n", "- Use the Recursive Newton-Euler Algorithm (RNEA).\n", "\n", "Unlike many other implementations, JaxSim's RNEA for floating-base systems is the true inverse of $\\text{FD}$. It also computes the 6D force applied to the base link that generates the base acceleration." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UTae5MjhaP2H" }, "outputs": [], "source": [ "f_B, τ_rnea = js.model.inverse_dynamics(\n", " model=model,\n", " data=data,\n", " base_acceleration=v̇_WB,\n", " joint_accelerations=s̈,\n", " # To check that f_B works, let's remove the force applied\n", " # to the base link from the link forces.\n", " link_forces=f_L.at[0].set(jnp.zeros(6))\n", ")\n", "\n", "print(f\"f_B: shape={f_B.shape}\")\n", "print(f\"τ_rnea: shape={τ_rnea.shape}\")\n", "\n", "# The following relations should hold.\n", "assert jnp.allclose(τ_rnea, τ)\n", "assert jnp.allclose(f_B, link_forces[0])" ] }, { "cell_type": "markdown", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gYZ1jK1Neg1H", "outputId": "0de79770-1e18-4027-bb47-5713bc1b4a72" }, "source": [ "### Centroidal Dynamics\n", "\n", "Centroidal dynamics is a useful simplification often employed in planning and control applications. It represents the dynamics projected onto a mixed frame associated with the center of mass (CoM):\n", "\n", "$$\n", "G = G[W] = ({}^W \\mathbf{p}_{\\text{CoM}}, [W])\n", ".\n", "$$\n", "\n", "The governing equations for centroidal dynamics take into account the 6D centroidal momentum:\n", "\n", "$$\n", "{}_G \\mathbf{h} =\n", "\\begin{bmatrix}\n", "{}_G \\mathbf{h}^l \\\\ {}_G \\mathbf{h}^\\omega\n", "\\end{bmatrix} =\n", "\\begin{bmatrix}\n", "m \\, {}^W \\dot{\\mathbf{p}}_\\text{CoM} \\\\ {}_G \\mathbf{h}^\\omega\n", "\\end{bmatrix}\n", "\\in \\mathbb{R}^6\n", ".\n", "$$\n", "\n", "The equations of centroidal dynamics can be expressed as:\n", "\n", "$$\n", "{}_G \\dot{\\mathbf{h}} =\n", "m \\,\n", "\\begin{bmatrix}\n", "{}^W \\mathbf{g} \\\\ \\mathbf{0}_3\n", "\\end{bmatrix} +\n", "\\sum_{C_i \\in \\mathcal{C}} {}_G \\mathbf{X}^{C_i} \\, {}_{C_i} \\mathbf{f}_i\n", ".\n", "$$\n", "\n", "While centroidal dynamics can function independently by considering the total mass $m \\in \\mathbb{R}$ of the robot and the 
transformations for 6D contact forces ${}_G \\mathbf{X}^{C_i}$ corresponding to the pose ${}^G \\mathbf{H}_{C_i} \\in \\text{SE}(3)$ of the contact frames, advanced kino-dynamic methods may require a relationship between full kinematics and centroidal dynamics. This is typically achieved through the _Centroidal Momentum Matrix_ (also known as the _centroidal momentum Jacobian_):\n", "\n", "$$\n", "{}_G \\mathbf{h} = J_\\text{CMM}(\\mathbf{q}) \\, \\boldsymbol{\\nu}\n", ".\n", "$$\n", "\n", "JaxSim offers APIs to compute all these quantities (and many more) in the `jaxsim.api.com` package." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rrSfxp8lh9YZ" }, "outputs": [], "source": [ "# Number of contact points.\n", "n_cp = len(model.kin_dyn_parameters.contact_parameters.body)\n", "print(\"Number of contact points:\", n_cp, \"\\n\")\n", "\n", "# Compute the centroidal momentum.\n", "J_CMM = js.com.centroidal_momentum_jacobian(model=model, data=data)\n", "G_h = J_CMM @ ν\n", "print(f\"G_h: shape={G_h.shape}\")\n", "print(f\"J_CMM: shape={J_CMM.shape}\")\n", "\n", "# The following relation should hold.\n", "assert jnp.allclose(G_h, js.com.centroidal_momentum(model=model, data=data))\n", "\n", "# If we consider all contact points of the model as active\n", "# (discourages since they might be too many), the 6D transforms of\n", "# collidable points can be computed as follows:\n", "W_H_C = js.contact.transforms(model=model, data=data)\n", "\n", "# Compute the pose of the G frame.\n", "W_p_CoM = js.com.com_position(model=model, data=data)\n", "G_H_W = jaxsim.math.Transform.inverse(jnp.eye(4).at[0:3, 3].set(W_p_CoM))\n", "\n", "# Convert from SE(3) to the transforms for 6D forces.\n", "G_Xf_C = jax.vmap(\n", " lambda W_H_Ci: jaxsim.math.Adjoint.from_transform(\n", " transform=G_H_W @ W_H_Ci, inverse=True\n", " )\n", ")(W_H_C)\n", "print(f\"G_Xf_C: shape={G_Xf_C.shape}\")\n", "\n", "# Let's create random 3D linear forces applied to the contact 
points.\n", "C_fl = jax.random.uniform(\n", " minval=-10.0,\n", " maxval=10.0,\n", " shape=(n_cp, 3),\n", " key=jax.random.PRNGKey(0),\n", ")\n", "\n", "# Compute the 3D gravity vector and the total mass of the robot.\n", "m = js.model.total_mass(model=model)\n", "\n", "# The centroidal dynamics can be computed as follows.\n", "G_ḣ = 0\n", "G_ḣ += m * jnp.hstack([0, 0, model.gravity, 0, 0, 0])\n", "G_ḣ += jnp.einsum(\"c66,c6->6\", G_Xf_C, jnp.hstack([C_fl, jnp.zeros_like(C_fl)]))\n", "print(f\"G_ḣ: shape={G_ḣ.shape}\")" ] }, { "cell_type": "markdown", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Ot6HePB_twaE", "outputId": "02a6abae-257e-45ee-e9de-6a607cdbeb9a" }, "source": [ "## Contact Frames\n", "\n", "Many control and planning applications require projecting the floating-base dynamics into the contact space or computing quantities related to active contact points, such as enforcing holonomic constraints.\n", "\n", "The underlying theory for these applications becomes clearer in a mixed representation. Specifically, the position, linear velocity, and linear acceleration of contact points in their corresponding mixed frame align with the numerical derivatives of their coordinate vectors.\n", "\n", "Key methodologies in this area may involve the Delassus matrix:\n", "\n", "$$\n", "\\Psi(\\mathbf{q}) = J_{W,C}(\\mathbf{q}) \\, M(\\mathbf{q})^{-1} \\, J_{W,C}^T(\\mathbf{q})\n", "$$\n", "\n", "or the linear acceleration of a contact point:\n", "\n", "$$\n", "{}^W \\ddot{\\mathbf{p}}_C = \\frac{\\text{d} (J^l_{W,C} \\boldsymbol{\\nu})}{\\text{d}t}\n", "= \\dot{J}^l_{W,C} \\boldsymbol{\\nu} + J^l_{W,C} \\dot{\\boldsymbol{\\nu}}\n", ".\n", "$$\n", "\n", "JaxSim offers APIs to compute all these quantities (and many more) in the `jaxsim.api.contact` package." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LITRC3STliKR" }, "outputs": [], "source": [ "with (\n", " data.switch_velocity_representation(VelRepr.Mixed),\n", " references.switch_velocity_representation(VelRepr.Mixed),\n", "):\n", "\n", " # Compute the mixed generalized velocity.\n", " BW_ν = data.generalized_velocity\n", "\n", " # Compute the mixed generalized acceleration.\n", " BW_ν̇ = jnp.hstack(\n", " js.model.forward_dynamics(\n", " model=model,\n", " data=data,\n", " link_forces=references.link_forces(model=model, data=data),\n", " joint_forces=references.joint_force_references(model=model),\n", " )\n", " )\n", "\n", " # Compute the mass matrix in mixed representation.\n", " BW_M = js.model.free_floating_mass_matrix(model=model, data=data)\n", "\n", " # Compute the contact Jacobian and its derivative.\n", " Jl_WC = js.contact.jacobian(model=model, data=data)[:, 0:3, :]\n", " J̇l_WC = js.contact.jacobian_derivative(model=model, data=data)[:, 0:3, :]\n", "\n", "# Compute the Delassus matrix.\n", "Ψ = jnp.vstack(Jl_WC) @ jnp.linalg.lstsq(BW_M, jnp.vstack(Jl_WC).T)[0]\n", "print(f\"Ψ: shape={Ψ.shape}\")\n", "\n", "# Compute the transforms of the mixed frames implicitly associated\n", "# to each collidable point.\n", "W_H_C = js.contact.transforms(model=model, data=data)\n", "print(f\"W_H_C: shape={W_H_C.shape}\")\n", "\n", "# Compute the linear velocity of the collidable points.\n", "with data.switch_velocity_representation(VelRepr.Mixed):\n", " W_ṗ_B = js.contact.collidable_point_velocities(model=model, data=data)[:, 0:3]\n", " print(f\"W_ṗ_B: shape={W_ṗ_B.shape}\")\n", "\n", "# Compute the linear acceleration of the collidable points.\n", "W_p̈_C = 0\n", "W_p̈_C += jnp.einsum(\"c3g,g->c3\", J̇l_WC, BW_ν)\n", "W_p̈_C += jnp.einsum(\"c3g,g->c3\", Jl_WC, BW_ν̇)\n", "print(f\"W_p̈_C: shape={W_p̈_C.shape}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "LITRC3STliKR" }, "source": [ "## Conclusions\n", "\n", "This notebook 
provided an overview of the main APIs in JaxSim for its use as a multibody dynamics library. Here are a few key points to remember:\n", "\n", "- Explore all the modules in the `jaxsim.api` package to discover the full range of APIs available. Many more functionalities exist beyond what was covered in this notebook.\n", "- All APIs follow a functional approach, consistent with the JAX programming style.\n", "- This functional design allows for easy application of `jax.vmap` to execute functions in parallel on hardware accelerators.\n", "- Since the entire multibody dynamics library is built with JAX, it natively supports `jax.grad`, `jax.jacfwd`, and `jax.jacrev` transformations, enabling automatic differentiation through complex logic without additional effort.\n", "\n", "Have fun!" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuClass": "premium", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "comodo_jaxsim", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: examples/jaxsim_as_physics_engine.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "H-WgcgGQaTG7" }, "source": [ "# JaxSim as a hardware-accelerated parallel physics engine\n", "\n", "This notebook shows how to use the key APIs to load a robot model and simulate multiple trajectories simultaneously.\n", "\n", "\n", " \"Open\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "SgOSnrSscEkt" }, "source": [ "## Prepare the environment" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fdqvAqMDaTG9" }, "outputs": [], "source": [ "# @title 
Imports and setup\n", "import os\n", "import sys\n", "from IPython.display import clear_output\n", "\n", "IS_COLAB = \"google.colab\" in sys.modules\n", "\n", "# Install JAX and Gazebo\n", "if IS_COLAB:\n", " !{sys.executable} -m pip install --pre -qU jaxsim\n", " !apt install -qq lsb-release wget gnupg\n", " !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n", " !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n", " !apt -qq update\n", " !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n", "\n", " clear_output()\n", "\n", "# Set environment variable to avoid GPU out of memory errors\n", "%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n", "\n", "\n", "# ================\n", "# Notebook imports\n", "# ================\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "import jaxsim.api as js\n", "from jaxsim import logging\n", "import pathlib\n", "\n", "logging.set_logging_level(logging.LoggingLevel.WARNING)\n", "print(f\"Running on {jax.devices()}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "NqjuZKvOaTG_" }, "source": [ "## Prepare the simulation\n", "\n", "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. 
In this example, we will load the [ergoCub][ergocub] model urdf.\n", "\n", "[sdformat]: http://sdformat.org/\n", "[urdf]: http://wiki.ros.org/urdf/\n", "[ergocub]: https://ergocub.eu/\n", "[rod]: https://github.com/gbionics/rod\n", "\n", "### Create the model and its data\n", " To define a simulation we need two main objects:\n", "\n", "- `model`: an object that defines the dynamics of the system.\n", "- `data`: an object that contains the state of the system.\n", "\n", "\n", "The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n", "To see the advanced usage, check the advanced example, where you will see how to pass explicitly an integrator class and state to the `model` object and how to change the contact model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create the model " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "etQ577cFaTHA" }, "outputs": [], "source": [ "# Create the JaxSim model.\n", "try:\n", " os.environ[\"ROBOT_DESCRIPTION_COMMIT\"] = \"v0.7.7\"\n", "\n", " import robot_descriptions.ergocub_description\n", "\n", "finally:\n", " _ = os.environ.pop(\"ROBOT_DESCRIPTION_COMMIT\", None)\n", "\n", "model_description_path = pathlib.Path(\n", " robot_descriptions.ergocub_description.URDF_PATH.replace(\n", " \"ergoCubSN002\", \"ergoCubSN001\"\n", " )\n", ")\n", "\n", "clear_output()\n", "\n", "full_model = js.model.JaxSimModel.build_from_model_description(\n", " model_description=model_description_path,\n", " time_step=0.0001,\n", " is_urdf=True\n", ")\n", "\n", "joints_list = tuple(('l_shoulder_pitch', 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow',\n", " 'r_shoulder_pitch', 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow',\n", " 'l_hip_pitch', 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',\n", " 'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll'))\n", "\n", "model = js.model.reduce(\n", " 
model=full_model,\n", " considered_joints=joints_list\n", ")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create the data object \n", "\n", "The data object is never changed by reference. Anytime you call a method aimed at modifying data, like `reset_base_position`, a new data object will be returned with the updated attributes while the original data will not be changed." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create the data of a single model.\n", "data = js.data.JaxSimModelData.build(model=model, base_position=jnp.array([0.0, 0.0, 1.0]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Simulation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create a random JAX key.\n", "\n", "key = jax.random.PRNGKey(seed=0)\n", "\n", "# Initialize the simulated time.\n", "T = jnp.arange(start=0, stop=0.3, step=model.time_step)\n", "\n", "# Simulate\n", "for _t in T:\n", " data = js.model.step(\n", " model=model,\n", " data=data,\n", " link_forces=None,\n", " joint_force_references=None,\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Vectorized simulation \n", "\n", "We will now vectorize the simulation on batched data using `jax.vmap`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# first we have to vmap the function\n", "\n", "import functools\n", "from typing import Any\n", "\n", "\n", "@jax.jit\n", "def step_single(\n", " model: js.model.JaxSimModel,\n", " data: js.data.JaxSimModelData,\n", ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n", "\n", " # Close step over static arguments.\n", " return js.model.step(\n", " model=model,\n", " data=data,\n", " link_forces=None,\n", " joint_force_references=None,\n", " )\n", "\n", "\n", "@jax.jit\n", "@functools.partial(jax.vmap, in_axes=(None, 0))\n", "def step_parallel(\n", " model: 
js.model.JaxSimModel,\n", " data: js.data.JaxSimModelData,\n", ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n", "\n", " return step_single(\n", " model=model, data=data\n", " )\n", "\n", "\n", "# Then we have to create the vector of initial state\n", "batch_size = 5\n", "data_batch_t0 = jax.vmap(\n", " lambda pos: js.data.JaxSimModelData.build(model=model, base_position=pos))(jnp.tile(jnp.array([0.0, 0.0, 1.0]), (batch_size, 1)))\n", "\n", "data = data_batch_t0\n", "for _t in T:\n", " data = step_parallel(model, data)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuClass": "premium", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "jaxsim", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.1" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: examples/jaxsim_as_physics_engine_advanced.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "H-WgcgGQaTG7" }, "source": [ "# JaxSim as a hardware-accelerated parallel physics engine-advanced usage\n", "\n", "JaxSim is developed to optimize synthetic data generation by sampling trajectories using hardware accelerators such as GPUs and TPUs.\n", "\n", "In this notebook, you'll learn how to use the key APIs to load a simple robot model (a sphere) and simulate multiple trajectories in parallel on GPUs.\n", "\n", "\n", " \"Open\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "SgOSnrSscEkt" }, "source": [ "## Prepare the environment" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fdqvAqMDaTG9" }, "outputs": [], "source": [ "# @title Imports and setup\n", "import sys\n", "from IPython.display import 
clear_output\n", "\n", "IS_COLAB = \"google.colab\" in sys.modules\n", "\n", "# Install JAX and Gazebo\n", "if IS_COLAB:\n", " !{sys.executable} -m pip install --pre -qU jaxsim[viz]\n", " !apt install -qq lsb-release wget gnupg\n", " !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n", " !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n", " !apt -qq update\n", " !apt install -qq --no-install-recommends libsdformat13 gz-tools2\n", "\n", " clear_output()\n", "\n", "# Set environment variable to avoid GPU out of memory errors\n", "%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\n", "\n", "# ================\n", "# Notebook imports\n", "# ================\n", "\n", "import os\n", "\n", "if sys.platform == 'darwin':\n", " os.environ[\"MUJOCO_GL\"] = \"glfw\"\n", "else:\n", " os.environ[\"MUJOCO_GL\"] = \"egl\"\n", "\n", "import jax\n", "\n", "import jax.numpy as jnp\n", "import jaxsim.api as js\n", "import rod\n", "from jaxsim import logging\n", "from rod.builder.primitives import SphereBuilder\n", "\n", "logging.set_logging_level(logging.LoggingLevel.WARNING)\n", "print(f\"Running on {jax.devices()}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "QtCCUhdpdGFH" }, "source": [ "## Prepare the simulation\n", "\n", "JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. This is done using the [`gbionics/rod`][rod] library, which processes these formats.\n", "\n", "The `rod` library also allows creating in-memory models that can be serialized to SDF or URDF. 
We'll use this functionality to build a sphere model, which will later be used to create the JaxSim model.\n", "\n", "[sdformat]: http://sdformat.org/\n", "[urdf]: http://wiki.ros.org/urdf/\n", "[rod]: https://github.com/gbionics/rod" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "0emoMQhCaTG_" }, "outputs": [], "source": [ "# @title Create the model description of a sphere\n", "\n", "# Create a SDF model.\n", "# The builder takes care to compute the right inertia tensor for you.\n", "rod_sdf = rod.Sdf(\n", " version=\"1.7\",\n", " model=SphereBuilder(radius=0.10, mass=1.0, name=\"sphere\")\n", " .build_model()\n", " .add_link()\n", " .add_inertial()\n", " .add_visual()\n", " .add_collision()\n", " .build(),\n", ")\n", "\n", "# Rod allows to update the frames w.r.t. the poses are expressed.\n", "rod_sdf.model.switch_frame_convention(\n", " frame_convention=rod.FrameConvention.Urdf, explicit_frames=True\n", ")\n", "\n", "# Serialize the model to a SDF string.\n", "model_sdf_string = rod_sdf.serialize(pretty=True)\n", "print(model_sdf_string)\n", "\n", "# JaxSim currently only supports collisions between points attached to bodies\n", "# and a ground surface modeled as a heightmap sampled from a smooth function.\n", "# While this approach is universal as it applies to generic meshes, the number\n", "# of considered points greatly affects the performance. Spheres, by default,\n", "# are discretized with 250 points. It's too much for this simple example.\n", "# This number can be decreased with the following environment variable.\n", "os.environ[\"JAXSIM_COLLISION_SPHERE_POINTS\"] = \"50\"" ] }, { "cell_type": "markdown", "metadata": { "id": "NqjuZKvOaTG_" }, "source": [ "### Create the model and its data\n", "\n", "JAXsim offers a simple functional API in order to interact in a memory-efficient way with the simulation. 
Four main elements are used to define a simulation:\n", "\n", "- `model`: an object that defines the dynamics of the system.\n", "- `data`: an object that contains the state of the system.\n", "- `integrator` *(Optional)*: an object that defines the integration method.\n", "- `integrator_metadata` *(Optional)*: an object that contains the state of the integrator.\n", "\n", "The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\n", "In this example, we will explicitly pass an integrator class to the `model` object and we will use the default `SoftContacts` contact model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "etQ577cFaTHA" }, "outputs": [], "source": [ "# Create the JaxSim model.\n", "# This is shared among all the parallel instances.\n", "model = js.model.JaxSimModel.build_from_model_description(\n", " model_description=model_sdf_string,\n", " time_step=0.001,\n", ")\n", "\n", "# Create the data of a single model.\n", "# We will create a vectorized instance later.\n", "data_single = js.data.JaxSimModelData.zero(model=model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "o86Teq5piVGj" }, "outputs": [], "source": [ "# Initialize the simulated time.\n", "T = jnp.arange(start=0, stop=1.0, step=model.time_step)" ] }, { "cell_type": "markdown", "metadata": { "id": "V6IeD2B3m4F0" }, "source": [ "## Sample a batch of trajectories in parallel\n", "\n", "With the provided resources, you can step through an open-loop trajectory on a single model using `jaxsim.api.model.step`.\n", "\n", "In this notebook, we'll focus on running parallel steps. We'll use JAX's automatic vectorization to apply the step function to batched data.\n", "\n", "Note that these parallel simulations are independent — models don't interact, so there's no need to avoid initial collisions." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vtEn0aIzr_2j" }, "outputs": [], "source": [ "# @title Generate batched initial data\n", "\n", "# Create a random JAX key.\n", "key = jax.random.PRNGKey(seed=0)\n", "\n", "# Split subkeys for sampling random initial data.\n", "batch_size = 9\n", "row_length = int(jnp.sqrt(batch_size))\n", "row_dist = 0.3 * row_length\n", "key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\n", "\n", "# Create the batched data by sampling the height from [0.5, 0.6] meters.\n", "data_batch_t0 = jax.vmap(\n", " lambda key: js.data.random_model_data(\n", " model=model,\n", " key=key,\n", " base_pos_bounds=([0, 0, 0.3], [0, 0, 1.2]),\n", " base_vel_lin_bounds=(0, 0),\n", " base_vel_ang_bounds=(0, 0),\n", " )\n", ")(jnp.vstack(subkeys))\n", "\n", "x, y = jnp.meshgrid(\n", " jnp.linspace(-row_dist, row_dist, num=row_length),\n", " jnp.linspace(-row_dist, row_dist, num=row_length),\n", ")\n", "xy_coordinate = jnp.stack([x.flatten(), y.flatten()], axis=-1)\n", "\n", "# Reset the x and y position to a grid.\n", "data_batch_t0 = data_batch_t0.replace(\n", " model=model,\n", " base_position=data_batch_t0.base_position.at[:, :2].set(xy_coordinate),\n", ")\n", "\n", "print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position[0:10])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0tQPfsl6uxHm" }, "outputs": [], "source": [ "# @title Create parallel step function\n", "\n", "import functools\n", "from typing import Any\n", "\n", "\n", "@jax.jit\n", "def step_single(\n", " model: js.model.JaxSimModel,\n", " data: js.data.JaxSimModelData,\n", ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n", "\n", " # Close step over static arguments.\n", " return js.model.step(\n", " model=model,\n", " data=data,\n", " link_forces=None,\n", " joint_force_references=None,\n", " )\n", "\n", "\n", "@jax.jit\n", "@functools.partial(jax.vmap, in_axes=(None, 0))\n", "def step_parallel(\n", " model: 
js.model.JaxSimModel,\n", " data: js.data.JaxSimModelData,\n", ") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\n", "\n", " return step_single(\n", " model=model, data=data\n", " )\n", "\n", "\n", "# The first run will be slow since JAX needs to JIT-compile the functions.\n", "_ = step_single(model, data_single)\n", "_ = step_parallel(model, data_batch_t0)\n", "\n", "# Benchmark the execution of a single step.\n", "print(\"\\nSingle simulation step:\")\n", "%timeit step_single(model, data_single)\n", "\n", "# On hardware accelerators, there's a range of batch_size values where\n", "# increasing the number of parallel instances doesn't affect computation time.\n", "# This range depends on the GPU/TPU specifications.\n", "print(f\"\\nParallel simulation steps (batch_size={batch_size} on {jax.devices()[0]}):\")\n", "%timeit step_parallel(model, data_batch_t0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VNwzT2JQ1n15" }, "outputs": [], "source": [ "# @title Run parallel simulation\n", "\n", "data = data_batch_t0\n", "data_trajectory_list = []\n", "\n", "for _ in T:\n", "\n", " data = step_parallel(model, data)\n", " data_trajectory_list.append(data)" ] }, { "cell_type": "markdown", "metadata": { "id": "Y6n720Cr3G44" }, "source": [ "## Visualize trajectory" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BLPODyKr3Lyg" }, "outputs": [], "source": [ "# Convert a list of PyTrees to a batched PyTree.\n", "# This operation is called 'tree transpose' in JAX.\n", "data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\n", "\n", "print(f\"W_p_B: shape={data_trajectory.base_position.shape}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-jxJXy5r3RMt" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "\n", "plt.plot(T, data_trajectory.base_position[:, :, 2])\n", "plt.grid(True)\n", "plt.xlabel(\"Time [s]\")\n", "plt.ylabel(\"Height 
[m]\")\n", "plt.title(\"Height trajectory of the sphere\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jaxsim.mujoco\n", "\n", "mjcf_string, assets = jaxsim.mujoco.ModelToMjcf.convert(\n", " model.built_from,\n", " cameras=jaxsim.mujoco.loaders.MujocoCamera.build_from_target_view(\n", " camera_name=\"sphere_cam\",\n", " lookat=[0, 0, 0.3],\n", " distance=4,\n", " azimuth=150,\n", " elevation=-10,\n", " ),\n", ")\n", "\n", "# Create a helper for each parallel instance.\n", "mj_model_helpers = [\n", " jaxsim.mujoco.MujocoModelHelper.build_from_xml(\n", " mjcf_description=mjcf_string, assets=assets\n", " )\n", " for _ in range(batch_size)\n", "]\n", "\n", "# Create the video recorder.\n", "recorder = jaxsim.mujoco.MujocoVideoRecorder(\n", " model=mj_model_helpers[0].model,\n", " data=[helper.data for helper in mj_model_helpers],\n", " fps=int(1 / model.time_step),\n", " width=320 * 2,\n", " height=240 * 2,\n", ")\n", "\n", "for data_t in data_trajectory_list:\n", "\n", " for helper, base_position, base_quaternion, joint_position in zip(\n", " mj_model_helpers,\n", " data_t.base_position,\n", " data_t.base_orientation,\n", " data_t.joint_positions,\n", " strict=True,\n", " ):\n", " helper.set_base_position(position=base_position)\n", " helper.set_base_orientation(orientation=base_quaternion)\n", "\n", " if model.dofs() > 0:\n", " helper.set_joint_positions(\n", " positions=joint_position, joint_names=model.joint_names()\n", " )\n", "\n", " # Record a new video frame.\n", " recorder.record_frame(camera_name=\"sphere_cam\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import mediapy as media\n", "\n", "media.show_video(recorder.frames, fps=recorder.fps)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuClass": "premium", "private_outputs": true, "provenance": [], 
"toc_visible": true }, "kernelspec": { "display_name": "jaxpypi", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.1" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: examples/jaxsim_for_robot_controllers.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "EhPy6FgiZH4d" }, "source": [ "# JaxSim for developing closed-loop robot controllers\n", "\n", "Originally developed as a **hardware-accelerated physics engine**, JaxSim has expanded its capabilities to become a full-featured **JAX-based multibody dynamics library**.\n", "\n", "In this notebook, you'll explore how to combine these two core features. Specifically, you'll learn how to load a robot model and design a model-based controller for closed-loop simulations.\n", "\n", "\n", " \"Open\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vsf1AlxdZH4f" }, "outputs": [], "source": [ "# @title Prepare the environment\n", "from IPython.display import clear_output\n", "import sys\n", "\n", "IS_COLAB = \"google.colab\" in sys.modules\n", "\n", "# Install JAX, sdformat, and other notebook dependencies.\n", "if IS_COLAB:\n", " !{sys.executable} -m pip install --pre -qU jaxsim[viz]\n", " !apt install -qq lsb-release wget gnupg\n", " !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\n", " !echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\n", " !apt -qq update\n", " !apt install -qq --no-install-recommends 
libsdformat13 gz-tools2\n", "\n", " clear_output()\n", "\n", "# ================\n", "# Notebook imports\n", "# ================\n", "\n", "import os\n", "\n", "if sys.platform == 'darwin':\n", " os.environ[\"MUJOCO_GL\"] = \"glfw\"\n", "else:\n", " os.environ[\"MUJOCO_GL\"] = \"egl\"\n", "\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "import jaxsim.mujoco\n", "from jaxsim import logging\n", "\n", "logging.set_logging_level(logging.LoggingLevel.WARNING)\n", "print(f\"Running on {jax.devices()}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "kN-b9nOsZH4g" }, "source": [ "We will use a simple cartpole model for this example. The cartpole model is a 2D model with a cart that can move horizontally and a pole that can rotate around the cart. The state of the cartpole is given by the position of the cart, the angle of the pole, the velocity of the cart, and the angular velocity of the pole. The control input is the horizontal force applied to the cart." ] }, { "cell_type": "markdown", "metadata": { "id": "5aLqrZDqR5LA" }, "source": [ "## Prepare the simulation\n", "\n", "JaxSim supports loading robot models from both [SDF][sdformat] and [URDF][urdf] files, utilizing the [`gbionics/rod`][rod] library for processing these formats.\n", "\n", "The `rod` library can read URDF files and validate them internally using [`gazebosim/sdformat`][sdformat_github].
In this example, we'll load a cart-pole model, which will be used to create the JaxSim simulation model.\n", "\n", "[sdformat]: http://sdformat.org/\n", "[urdf]: http://wiki.ros.org/urdf/\n", "[rod]: https://github.com/gbionics/rod\n", "[sdformat_github]: https://github.com/gazebosim/sdformat" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.path.abspath(\"\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PZM7hEvFZH4h" }, "outputs": [], "source": [ "# @title Load the URDF model\n", "import pathlib\n", "import urllib\n", "\n", "# Retrieve the file\n", "url = \"https://raw.githubusercontent.com/gbionics/jaxsim/refs/heads/main/examples/assets/cartpole.urdf\"\n", "model_path, _ = urllib.request.urlretrieve(url)\n", "model_urdf_string = pathlib.Path(model_path).read_text()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "M5XsKehvZH4j" }, "outputs": [], "source": [ "# @title Create the model and its data\n", "\n", "import jaxsim.api as js\n", "\n", "# Create the model from the model description.\n", "model = js.model.JaxSimModel.build_from_model_description(\n", " model_description=model_urdf_string,\n", " time_step=0.010,\n", ")\n", "\n", "# Create the data storing the simulation state.\n", "data_zero = js.data.JaxSimModelData.zero(model=model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jk9csR5ETgn1" }, "outputs": [], "source": [ "# @title Define simulation parameters\n", "\n", "# Initialize the simulated time.\n", "T = jnp.arange(start=0, stop=5.0, step=model.time_step)" ] }, { "cell_type": "markdown", "metadata": { "id": "bo6Ke5nAWL-S" }, "source": [ "## Prepare the MuJoCo renderer\n", "\n", "For visualization purposes, we use the passive viewer of the MuJoCo simulator. It can either open an interactive window when used locally or record a video when used in notebooks."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "j1_I2i5TZH4n" }, "outputs": [], "source": [ "# Create the MJCF resources from the URDF.\n", "mjcf_string, assets = jaxsim.mujoco.UrdfToMjcf.convert(\n", " urdf=model.built_from,\n", " # Create the camera used by the recorder.\n", " cameras=jaxsim.mujoco.loaders.MujocoCamera.build_from_target_view(\n", " camera_name=\"cartpole_camera\",\n", " lookat=js.link.com_position(\n", " model=model,\n", " data=data_zero,\n", " link_index=js.link.name_to_idx(model=model, link_name=\"cart\"),\n", " in_link_frame=False,\n", " ),\n", " distance=3,\n", " azimuth=150,\n", " elevation=-10,\n", " ),\n", ")\n", "\n", "# Create a helper to operate on the MuJoCo model and data.\n", "mj_model_helper = jaxsim.mujoco.MujocoModelHelper.build_from_xml(\n", " mjcf_description=mjcf_string, assets=assets\n", ")\n", "\n", "# Create the video recorder.\n", "recorder = jaxsim.mujoco.MujocoVideoRecorder(\n", " model=mj_model_helper.model,\n", " data=mj_model_helper.data,\n", " fps=int(1 / model.time_step),\n", " width=320 * 2,\n", " height=240 * 2,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "DpRvvGujZH4o" }, "source": [ "## Open-loop simulation\n", "\n", "Now, let's run a simulation to demonstrate the open-loop dynamics of the system." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gSWzcsKWZH4p" }, "outputs": [], "source": [ "import mediapy as media\n", "\n", "\n", "# Create a random joint position.\n", "# For a random full state, you can use jaxsim.api.data.random_model_data.\n", "random_joint_positions = jax.random.uniform(\n", " minval=-1.0,\n", " maxval=1.0,\n", " shape=(model.dofs(),),\n", " key=jax.random.PRNGKey(0),\n", ")\n", "\n", "# Reset the state to the random joint positions.\n", "data = js.data.JaxSimModelData.build(model=model, joint_positions=random_joint_positions)\n", "\n", "for _ in T:\n", "\n", " # Step the JaxSim simulation.\n", " data = js.model.step(\n", " model=model,\n", " data=data,\n", " joint_force_references=None,\n", " link_forces=None,\n", " )\n", "\n", " # Update the MuJoCo data.\n", " mj_model_helper.set_joint_positions(\n", " positions=data.joint_positions, joint_names=model.joint_names()\n", " )\n", "\n", " # Record a new video frame.\n", " recorder.record_frame(camera_name=\"cartpole_camera\")\n", "\n", "\n", "# Play the video.\n", "media.show_video(recorder.frames, fps=recorder.fps)\n", "recorder.frames = []" ] }, { "cell_type": "markdown", "metadata": { "id": "j1rguK3UZH4p" }, "source": [ "## Closed-loop simulation\n", "\n", "Next, let's design a simple computed torque controller. 
The equations of motion for the cart-pole system are given by:\n", "\n", "$$\n", "M_{ss}(\\mathbf{s}) \\, \\ddot{\\mathbf{s}} + \\mathbf{h}_s(\\mathbf{s}, \\dot{\\mathbf{s}}) = \\boldsymbol{\\tau}\n", ",\n", "$$\n", "\n", "where:\n", "\n", "- $\\mathbf{s} \\in \\mathbb{R}^n$ are the joint positions.\n", "- $\\dot{\\mathbf{s}} \\in \\mathbb{R}^n$ are the joint velocities.\n", "- $\\ddot{\\mathbf{s}} \\in \\mathbb{R}^n$ are the joint accelerations.\n", "- $\\boldsymbol{\\tau} \\in \\mathbb{R}^n$ are the joint torques.\n", "- $M_{ss} \\in \\mathbb{R}^{n \\times n}$ is the mass matrix.\n", "- $\\mathbf{h}_s \\in \\mathbb{R}^n$ is the vector of bias forces.\n", "\n", "JaxSim computes these quantities for floating-base systems, so we specifically focus on the joint-related portions by marking them with subscripts.\n", "\n", "Since no external forces or joint friction are present, we can extend a PD controller with a feed-forward term that includes gravity compensation:\n", "\n", "$$\n", "\\begin{cases}\n", "\\boldsymbol{\\tau} &= M_{ss} \\, \\ddot{\\mathbf{s}}^* + \\mathbf{h}_s \\\\\n", "\\ddot{\\mathbf{s}}^* &= \\ddot{\\mathbf{s}}^\\text{des} - k_p(\\mathbf{s} - \\mathbf{s}^{\\text{des}}) - k_d(\\dot{\\mathbf{s}} - \\dot{\\mathbf{s}}^{\\text{des}})\n", "\\end{cases}\n", "\\quad\n", ",\n", "$$\n", "\n", "where $\\tilde{\\mathbf{s}} = \\left(\\mathbf{s} - \\mathbf{s}^\\text{des}\\right)$ is the joint position error.\n", "\n", "With this control law, the closed-loop system dynamics simplifies to:\n", "\n", "$$\n", "\\ddot{\\tilde{\\mathbf{s}}} = -k_p \\tilde{\\mathbf{s}} - k_d \\dot{\\tilde{\\mathbf{s}}}\n", ",\n", "$$\n", "\n", "so, for positive gains $k_p$ and $k_d$, the position error converges asymptotically to zero, ensuring stability."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rfMTCMyGZH4q" }, "outputs": [], "source": [ "# @title Create the computed torque controller\n", "\n", "# Define the PD gains\n", "kp = 10.0\n", "kd = 6.0\n", "\n", "\n", "def computed_torque_controller(\n", " data: js.data.JaxSimModelData,\n", " s_des: jax.Array,\n", " s_dot_des: jax.Array,\n", ") -> jax.Array:\n", "\n", " # Compute the gravity compensation term.\n", " hs = js.model.free_floating_bias_forces(model=model, data=data)[6:]\n", "\n", " # Compute the joint-related portion of the floating-base mass matrix.\n", " Mss = js.model.free_floating_mass_matrix(model=model, data=data)[6:, 6:]\n", "\n", " # Get the current joint positions and velocities.\n", " s = data.joint_positions\n", " ṡ = data.joint_velocities\n", "\n", " # Compute the actuated joint torques.\n", " s_star = -kp * (s - s_des) - kd * (ṡ - s_dot_des)\n", " τ = Mss @ s_star + hs\n", "\n", " return τ" ] }, { "cell_type": "markdown", "metadata": { "id": "ERAUisywZH4q" }, "source": [ "Now, we can use the `computed_torque_controller` function to compute the torques to apply to the cartpole. Our aim is to stabilize the cartpole in the upright position, so we set both the desired joint positions `s_des` and the desired joint velocities `s_dot_des` to zero."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8YmDdGDVZH4q" }, "outputs": [], "source": [ "# @title Run the simulation\n", "\n", "# Initialize the data.\n", "\n", "# Set the joint positions.\n", "data = js.data.JaxSimModelData.build(model=model, joint_positions=jnp.array([-0.25, jnp.deg2rad(160)]), joint_velocities=jnp.array([3.00, jnp.deg2rad(10) / model.time_step]))\n", "\n", "for _ in T:\n", "\n", " # Get the actuated torques from the computed torque controller.\n", " τ = computed_torque_controller(\n", " data=data,\n", " s_des=jnp.array([0.0, 0.0]),\n", " s_dot_des=jnp.array([0.0, 0.0]),\n", " )\n", "\n", " # Step the JaxSim simulation.\n", " data = js.model.step(\n", " model=model,\n", " data=data,\n", " joint_force_references=τ,\n", " )\n", "\n", " # Update the MuJoCo data.\n", " mj_model_helper.set_joint_positions(\n", " positions=data.joint_positions, joint_names=model.joint_names()\n", " )\n", "\n", " # Record a new video frame.\n", " recorder.record_frame(camera_name=\"cartpole_camera\")\n", "\n", "media.show_video(recorder.frames, fps=recorder.fps)\n", "recorder.frames = []" ] }, { "cell_type": "markdown", "metadata": { "id": "sZ76QqeWeMQz" }, "source": [ "## Conclusions\n", "\n", "In this notebook, we explored how to use JaxSim for developing a closed-loop controller for a robot model. Key takeaways include:\n", "\n", "- We performed an open-loop simulation to understand the dynamics of the system without control.\n", "- We implemented a computed torque controller with PD feedback and a feed-forward gravity compensation term, enabling the stabilization of the system by controlling joint torques.\n", "- The closed-loop simulation can leverage hardware acceleration on GPUs and TPUs, with the ability to use `jax.vmap` for parallel sampling through automatic vectorization.\n", "\n", "JaxSim's closed-loop support can be extended to more advanced, model-based reactive controllers and planners for trajectory optimization. 
To explore optimization-based methods, consider the following JAX-based projects for hardware-accelerated control and planning:\n", "\n", "- [`deepmind/optax`](https://github.com/google-deepmind/optax)\n", "- [`google/jaxopt`](https://github.com/google/jaxopt)\n", "- [`patrick-kidger/lineax`](https://github.com/patrick-kidger/lineax)\n", "- [`patrick-kidger/optimistix`](https://github.com/patrick-kidger/optimistix)\n", "- [`kevin-tracy/qpax`](https://github.com/kevin-tracy/qpax)\n", "\n", "Additionally, if your controllers or planners require the derivatives of the dynamics with respect to the state or inputs, you can obtain them using automatic differentiation directly through JaxSim's API." ] }, { "cell_type": "markdown", "metadata": {}, "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuClass": "premium", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "comodo_jaxsim", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: pyproject.toml ================================================ [project] name = "jaxsim" dynamic = ["version"] requires-python = ">= 3.10" description = "A differentiable physics engine and multibody dynamics library for control and robot learning." 
authors = [ { name = "Diego Ferigo", email = "dgferigo@gmail.com" }, { name = "Filippo Luca Ferretti", email = "filippoluca.ferretti@outlook.com" }, ] maintainers = [ { name = "Filippo Luca Ferretti", email = "filippo.ferretti@outlook.com" }, ] license = "BSD-3-Clause" license-files = ["LICENSE"] keywords = [ "physics", "physics engine", "jax", "rigid body dynamics", "featherstone", "reinforcement learning", "robot", "robotics", "sdf", "urdf", ] classifiers = [ "Development Status :: 4 - Beta", "Framework :: Robot Framework", "Intended Audience :: Developers", "Intended Audience :: Science/Research", "Operating System :: POSIX :: Linux", "Operating System :: MacOS", "Operating System :: Microsoft", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.14", "Programming Language :: Python :: Implementation :: CPython", "Topic :: Games/Entertainment :: Simulation", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Physics", "Topic :: Software Development", ] dependencies = [ "coloredlogs", "jax >= 0.4.34", "jaxlib >= 0.4.34", "jaxlie >= 1.3.0", "jax_dataclasses >= 1.4.0", "pptree", "optax >= 0.2.3", "qpax", "rod >= 0.4.1", "typing_extensions ; python_version < '3.12'", "trimesh", ] [project.optional-dependencies] testing = [ "chex >= 0.1.91", "idyntree >= 12.2.1", "pytest >=6.0", "pytest-benchmark", "pytest-icdiff", "pytest-xdist", "robot-descriptions >= 1.16.0", "icub-models", ] viz = [ "lxml", "mediapy", "mujoco >= 3.0.0", "scipy >= 1.14.0", ] all = [ "jaxsim[testing,viz]", ] [project.readme] file = "README.md" content-type = "text/markdown" [project.urls] Changelog = "https://github.com/gbionics/jaxsim/releases" Documentation = "https://jaxsim.readthedocs.io" Source = "https://github.com/gbionics/jaxsim" Tracker = 
"https://github.com/gbionics/jaxsim/issues" # =========== # Build tools # =========== [build-system] build-backend = "hatchling.build" requires = [ "hatchling", "hatch-vcs", ] [tool.hatch.version] source = "vcs" raw-options = { local_scheme = "dirty-tag" } [tool.hatch.build.targets.wheel] packages = ["src/jaxsim"] [tool.hatch.build.hooks.vcs] version-file = "src/jaxsim/_version.py" # ================= # Style and testing # ================= [tool.black] line-length = 88 [tool.isort] multi_line_output = 3 profile = "black" [tool.pytest.ini_options] addopts = "-rsxX -v --strict-markers --benchmark-skip --benchmark-warmup=ON" minversion = "6.0" testpaths = [ "tests", ] # ================== # Ruff configuration # ================== [tool.ruff] exclude = [ ".git", ".pixi", ".pytest_cache", ".ruff_cache", ".idea", ".vscode", ".devcontainer", "__pycache__", ] preview = true [tool.ruff.lint] # https://docs.astral.sh/ruff/rules/ select = [ "B", "D", "E", "F", "I", "W", "RUF", "UP", "YTT", ] ignore = [ "B008", # Function call in default argument "B024", # Abstract base class without abstract methods "D100", # Missing docstring in public module "D104", # Missing docstring in public package "D105", # Missing docstring in magic method "D200", # One-line docstring should fit on one line with quotes "D202", # No blank lines allowed after function docstring "D203", # Incorrect blank line before class "D205", # 1 blank line required between summary line and description "D212", # Multi-line docstring summary should start at the first line "D411", # Missing blank line before section "D413", # Missing blank line after last section "E402", # Module level import not at top of file "E501", # Line too long "E731", # Do not assign a `lambda` expression, use a `def` "E741", # Ambiguous variable name "I001", # Import block is unsorted or unformatted "RUF003", # Ambiguous unicode character in comment ] [tool.ruff.lint.per-file-ignores] # Ignore `E402` (import violations) in all `__init__.py` 
files "**/{tests,docs,tools}/*" = ["E402"] "**/{tests,examples}/*" = ["B007", "D100", "D102", "D103"] "__init__.py" = ["F401", "RUF067"] "docs/conf.py" = ["F401"] "src/jaxsim/exceptions.py" = ["D401"] "src/jaxsim/logging.py" = ["D101", "D103"] # ================== # Pixi configuration # ================== [tool.pixi.workspace] channels = ["conda-forge"] platforms = ["linux-64", "linux-aarch64", "osx-arm64", "osx-64"] requires-pixi = ">=0.39.0" preview = ["pixi-build"] [tool.pixi.environments] # We resolve only two groups: cpu and gpu. # Then, multiple environments can be created from these groups. default = { features = ["test", "examples"] } gpu = { features = ["test", "examples", "gpu"] } # --------------- # feature.default # --------------- # Dependencies from conda-forge. [tool.pixi.dependencies] # # Matching `project.dependencies`. # coloredlogs = "*" jax = "*" jaxlib = "*" jaxlie = "*" jax-dataclasses = "*" pptree = "*" optax = "*" qpax = "*" rod = ">=0.4.1" trimesh = "*" typing_extensions = "*" # # Optional dependencies. # lxml = "*" mediapy = "*" mujoco = "*" scipy = "*" # # Additional dependencies. # pip = "*" hatchling = "*" hatch-vcs = "*" # Dependencies from PyPI. 
[tool.pixi.pypi-dependencies] jaxsim = { path = "./", editable = true } [tool.pixi.pypi-options] no-build-isolation = ["jaxsim"] # ------------ # feature.test # ------------ [tool.pixi.feature.test.tasks] pipcheck = "pip check" benchmark = { cmd = "pytest --benchmark-only --benchmark-warmup=ON", depends-on = ["pipcheck"] } test = { cmd = "pytest", depends-on = ["pipcheck"] } [tool.pixi.feature.test.dependencies] black-jupyter = "*" chex = ">=0.1.91" idyntree = "*" isort = "*" pre-commit = "*" pytest = "*" pytest-benchmark = "*" pytest-icdiff = "*" pytest-xdist = "*" robot_descriptions = ">=1.16.0" # ---------------- # feature.examples # ---------------- [tool.pixi.feature.examples.tasks] examples = { cmd = "jupyter notebook ./examples" } [tool.pixi.feature.examples.dependencies] notebook = "*" robot_descriptions = ">=1.16.0" # ----------- # feature.gpu # ----------- [tool.pixi.feature.gpu] platforms = ["linux-64"] system-requirements = { cuda = "13" } [tool.pixi.feature.gpu.dependencies] jaxlib = { version = "*", build = "*cuda*" } [tool.pixi.feature.gpu.tasks] test-gpu = { cmd = "pytest --gpu-only", depends-on = ["pipcheck"] } ================================================ FILE: src/jaxsim/__init__.py ================================================ from . import logging from ._version import __version__ # Follow upstream development in https://github.com/google/jax/pull/13304 def _jnp_options() -> None: import os import jax # Check if running on TPU. is_tpu = jax.devices()[0].platform == "tpu" # Check if running on Metal. is_metal = jax.devices()[0].platform == "METAL" # Enable by default 64-bit precision to get accurate physics. # Users can enforce 32-bit precision by setting the following variable to 0. use_x64 = os.environ.get("JAX_ENABLE_X64", "1") != "0" # Notify the user if unsupported 64-bit precision was enforced on TPU. if (is_tpu or is_metal) and use_x64: msg = f"64-bit precision is not allowed on {jax.devices()[0].platform.upper}. 
Enforcing 32bit precision." logging.warning(msg) use_x64 = False if is_metal: logging.warning( "JAX Metal backend is experimental. Some functionalities may not be available." ) # Enable 64-bit precision in JAX. if use_x64: logging.info("Enabling JAX to use 64-bit precision") jax.config.update("jax_enable_x64", True) # Warn about experimental usage of 32-bit precision. else: logging.warning( "Using 32-bit precision in JaxSim is still experimental, please avoid to use variable step integrators." ) def _np_options() -> None: import numpy as np np.set_printoptions(precision=5, suppress=True, linewidth=150, threshold=10_000) def _is_editable() -> bool: import importlib.util import pathlib import site # Get the ModuleSpec of jaxsim. jaxsim_spec = importlib.util.find_spec(name="jaxsim") # This can be None. If it's None, assume non-editable installation. if jaxsim_spec.origin is None: return False # Get the folder containing the jaxsim package. jaxsim_package_dir = str(pathlib.Path(jaxsim_spec.origin).parent.parent) # The installation is editable if the package dir is not in any {site|dist}-packages. return jaxsim_package_dir not in site.getsitepackages() def _get_default_logging_level() -> logging.LoggingLevel: """ Get the default logging level. Returns: The logging level to set. """ import os import sys # Allow to override the default logging level with an environment variable. if overriden_logging_level := os.environ.get("JAXSIM_LOGGING_LEVEL"): try: return logging.LoggingLevel[overriden_logging_level.upper()] except KeyError as exc: msg = "Invalid logging level defined in JAXSIM_LOGGING_LEVEL" raise RuntimeError(msg) from exc # If running under a debugger, set the logging level to DEBUG. if getattr(sys, "gettrace", lambda: None)(): return logging.LoggingLevel.DEBUG # If not running under a debugger, set the logging level to INFO or WARNING. # INFO for editable installations, WARNING for non-editable installations. 
# This is to avoid too verbose logging in non-editable installations. return ( logging.LoggingLevel.INFO if _is_editable() # noqa: F821 else logging.LoggingLevel.WARNING ) # Configure the logger with the default logging level. logging.configure(level=_get_default_logging_level()) # Configure JAX. _jnp_options() # Initialize the numpy print options. _np_options() del _jnp_options del _np_options del _get_default_logging_level del _is_editable from . import terrain # isort:skip from . import api, logging, math, rbda from .api.common import VelRepr ================================================ FILE: src/jaxsim/api/__init__.py ================================================ from . import common # isort:skip from . import model, data # isort:skip from . import ( actuation_model, com, contact, frame, integrators, joint, kin_dyn_parameters, link, ode, references, ) ================================================ FILE: src/jaxsim/api/actuation_model.py ================================================ import jax.numpy as jnp import jaxsim.api as js import jaxsim.typing as jtp def compute_resultant_torques( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, joint_force_references: jtp.Vector | None = None, ) -> jtp.Vector: """ Compute the resultant torques acting on the joints. Args: model: The model to consider. data: The data of the considered model. joint_force_references: The joint force references to apply. Returns: The resultant torques acting on the joints. """ # Build joint torques if not provided. τ_references = ( jnp.atleast_1d(joint_force_references.squeeze()) if joint_force_references is not None else jnp.zeros_like(data.joint_positions) ).astype(float) # ==================== # Enforce joint limits # ==================== τ_position_limit = jnp.zeros_like(τ_references).astype(float) if model.dofs() > 0: # Stiffness and damper parameters for the joint position limits. 
k_j = jnp.array( model.kin_dyn_parameters.joint_parameters.position_limit_spring ).astype(float) d_j = jnp.array( model.kin_dyn_parameters.joint_parameters.position_limit_damper ).astype(float) # Compute the joint position limit violations. lower_violation = jnp.clip( data.joint_positions - model.kin_dyn_parameters.joint_parameters.position_limits_min, max=0.0, ) upper_violation = jnp.clip( data.joint_positions - model.kin_dyn_parameters.joint_parameters.position_limits_max, min=0.0, ) # Compute the joint position limit torque. τ_position_limit -= jnp.diag(k_j) @ (lower_violation + upper_violation) τ_position_limit -= ( jnp.positive(τ_position_limit) * jnp.diag(d_j) @ data.joint_velocities ) # ==================== # Joint friction model # ==================== τ_friction = jnp.zeros_like(τ_references).astype(float) # Apply joint friction only if enabled in the actuation parameters. if model.dofs() > 0 and model.actuation_params.enable_friction: # Static and viscous joint friction parameters kc = jnp.array( model.kin_dyn_parameters.joint_parameters.friction_static ).astype(float) kv = jnp.array( model.kin_dyn_parameters.joint_parameters.friction_viscous ).astype(float) # Compute the joint friction torque. τ_friction = -( jnp.diag(kc) @ jnp.sign(data.joint_velocities) + jnp.diag(kv) @ data.joint_velocities ) # =============================== # Compute the total joint forces. # =============================== τ_total = τ_references + τ_friction + τ_position_limit τ_lim = tn_curve_fn(model=model, data=data) τ_total = jnp.clip(τ_total, -τ_lim, τ_lim) return τ_total def tn_curve_fn( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: """ Compute the torque limits using the tn curve. Args: model: The model to consider. data: The data of the considered model. Returns: The torque limits. 
""" τ_max = model.actuation_params.torque_max # Max torque (Nm) ω_th = model.actuation_params.omega_th # Threshold speed (rad/s) ω_max = model.actuation_params.omega_max # Max speed for torque drop-off (rad/s) abs_vel = jnp.abs(data.joint_velocities) τ_lim = jnp.where( abs_vel <= ω_th, τ_max, jnp.where( abs_vel <= ω_max, τ_max * (1 - (abs_vel - ω_th) / (ω_max - ω_th)), 0.0 ), ) return τ_lim ================================================ FILE: src/jaxsim/api/com.py ================================================ import jax import jax.numpy as jnp import jaxsim.api as js import jaxsim.math import jaxsim.typing as jtp from .common import VelRepr @jax.jit @js.common.named_scope def com_position( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: """ Compute the position of the center of mass of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The position of the center of mass of the model w.r.t. the world frame. """ m = js.model.total_mass(model=model) W_H_L = data._link_transforms W_H_B = data._base_transform B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B) def B_p̃_LCoM(i) -> jtp.Vector: m = js.link.mass(model=model, link_index=i) L_p_LCoM = js.link.com_position( model=model, data=data, link_index=i, in_link_frame=True ) return m * B_H_W @ W_H_L[i] @ jnp.hstack([L_p_LCoM, 1]) com_links = jax.vmap(B_p̃_LCoM)(jnp.arange(model.number_of_links())) B_p̃_CoM = (1 / m) * com_links.sum(axis=0) B_p̃_CoM = B_p̃_CoM.at[3].set(1) return (W_H_B @ B_p̃_CoM)[0:3].astype(float) @jax.jit @js.common.named_scope def com_linear_velocity( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: r""" Compute the linear velocity of the center of mass of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The linear velocity of the center of mass of the model in the active representation. 
Note: The linear velocity of the center of mass is expressed in the mixed frame :math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the active velocity representation is either inertial-fixed or mixed, and :math:`[C] = [B]` if the active velocity representation is body-fixed. """ # Extract the linear component of the 6D average centroidal velocity. # This is expressed in G[B] in body-fixed representation, and in G[W] in # inertial-fixed or mixed representation. G_vl_WG = average_centroidal_velocity(model=model, data=data)[0:3] return G_vl_WG @jax.jit @js.common.named_scope def centroidal_momentum( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: r""" Compute the centroidal momentum of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The centroidal momentum of the model. Note: The centroidal momentum is expressed in the mixed frame :math:`({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`C = W` if the active velocity representation is either inertial-fixed or mixed, and :math:`C = B` if the active velocity representation is body-fixed. """ ν = data.generalized_velocity G_J = centroidal_momentum_jacobian(model=model, data=data) return G_J @ ν @jax.jit @js.common.named_scope def centroidal_momentum_jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: r""" Compute the Jacobian of the centroidal momentum of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The Jacobian of the centroidal momentum of the model. Note: The frame corresponding to the output representation of this Jacobian is either :math:`G[W]`, if the active velocity representation is inertial-fixed or mixed, or :math:`G[B]`, if the active velocity representation is body-fixed. Note: This Jacobian is also known in the literature as Centroidal Momentum Matrix. 
""" # Compute the Jacobian of the total momentum with body-fixed output representation. # We convert the output representation either to G[W] or G[B] below. B_Jh = js.model.total_momentum_jacobian( model=model, data=data, output_vel_repr=VelRepr.Body ) W_H_B = data._base_transform B_H_W = jaxsim.math.Transform.inverse(W_H_B) W_p_CoM = com_position(model=model, data=data) match data.velocity_representation: case VelRepr.Inertial | VelRepr.Mixed: W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841 case VelRepr.Body: W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841 case _: raise ValueError(data.velocity_representation) # Compute the transform for 6D forces. G_Xf_B = jaxsim.math.Adjoint.from_transform(transform=B_H_W @ W_H_G).T return G_Xf_B @ B_Jh @jax.jit @js.common.named_scope def locked_centroidal_spatial_inertia( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ): """ Compute the locked centroidal spatial inertia of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The locked centroidal spatial inertia of the model. """ with data.switch_velocity_representation(VelRepr.Body): B_Mbb_B = js.model.locked_spatial_inertia(model=model, data=data) W_H_B = data._base_transform W_p_CoM = com_position(model=model, data=data) match data.velocity_representation: case VelRepr.Inertial | VelRepr.Mixed: W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) # noqa: F841 case VelRepr.Body: W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM) # noqa: F841 case _: raise ValueError(data.velocity_representation) B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G B_Xv_G = jaxsim.math.Adjoint.from_transform(transform=B_H_G) G_Xf_B = B_Xv_G.transpose() return G_Xf_B @ B_Mbb_B @ B_Xv_G @jax.jit @js.common.named_scope def average_centroidal_velocity( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: r""" Compute the average centroidal velocity of the model. Args: model: The model to consider. 
data: The data of the considered model. Returns: The average centroidal velocity of the model. Note: The average velocity is expressed in the mixed frame :math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the active velocity representation is either inertial-fixed or mixed, and :math:`[C] = [B]` if the active velocity representation is body-fixed. """ ν = data.generalized_velocity G_J = average_centroidal_velocity_jacobian(model=model, data=data) return G_J @ ν @jax.jit @js.common.named_scope def average_centroidal_velocity_jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: r""" Compute the Jacobian of the average centroidal velocity of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The Jacobian of the average centroidal velocity of the model. Note: The frame corresponding to the output representation of this Jacobian is either :math:`G[W]`, if the active velocity representation is inertial-fixed or mixed, or :math:`G[B]`, if the active velocity representation is body-fixed. """ G_J = centroidal_momentum_jacobian(model=model, data=data) G_Mbb = locked_centroidal_spatial_inertia(model=model, data=data) return jnp.linalg.inv(G_Mbb) @ G_J @jax.jit @js.common.named_scope def bias_acceleration( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: r""" Compute the bias linear acceleration of the center of mass. Args: model: The model to consider. data: The data of the considered model. Returns: The bias linear acceleration of the center of mass in the active representation. Note: The bias acceleration is expressed in the mixed frame :math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the active velocity representation is either inertial-fixed or mixed, and :math:`[C] = [B]` if the active velocity representation is body-fixed. """ # Compute the pose of all links with forward kinematics. 
W_H_L = data._link_transforms # Compute the bias acceleration of all links by zeroing the generalized velocity # in the active representation. v̇_bias_WL = js.model.link_bias_accelerations(model=model, data=data) def other_representation_to_body( C_v̇_WL: jtp.Vector, C_v_WC: jtp.Vector, L_H_C: jtp.Matrix, L_v_LC: jtp.Vector ) -> jtp.Vector: """ Convert the body-fixed representation of the link bias acceleration C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL. """ L_X_C = jaxsim.math.Adjoint.from_transform(transform=L_H_C) C_X_L = jaxsim.math.Adjoint.inverse(L_X_C) L_v̇_WL = L_X_C @ (C_v̇_WL + jaxsim.math.Cross.vx(C_X_L @ L_v_LC) @ C_v_WC) return L_v̇_WL # We need here to get the body-fixed bias acceleration of the links. # Since it's computed in the active representation, we need to convert it to body. match data.velocity_representation: case VelRepr.Body: L_a_bias_WL = v̇_bias_WL case VelRepr.Inertial: C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL # noqa: F841 C_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 L_H_C = L_H_W = jax.vmap(jaxsim.math.Transform.inverse)(W_H_L) # noqa: F841 L_v_LC = L_v_LW = jax.vmap( # noqa: F841 lambda i: -js.link.velocity( model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body ) )(jnp.arange(model.number_of_links())) L_a_bias_WL = jax.vmap( lambda i: other_representation_to_body( C_v̇_WL=C_v̇_WL[i], C_v_WC=C_v_WC, L_H_C=L_H_C[i], L_v_LC=L_v_LC[i], ) )(jnp.arange(model.number_of_links())) case VelRepr.Mixed: C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL # noqa: F841 C_v_WC = LW_v_W_LW = jax.vmap( # noqa: F841 lambda i: js.link.velocity( model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed ) .at[3:6] .set(jnp.zeros(3)) )(jnp.arange(model.number_of_links())) L_H_C = L_H_LW = jax.vmap( # noqa: F841 lambda W_H_L: jaxsim.math.Transform.inverse( W_H_L.at[0:3, 3].set(jnp.zeros(3)) ) )(W_H_L) L_v_LC = L_v_L_LW = jax.vmap( # noqa: F841 lambda i: -js.link.velocity( model=model, data=data, link_index=i, 
output_vel_repr=VelRepr.Body ) .at[0:3] .set(jnp.zeros(3)) )(jnp.arange(model.number_of_links())) L_a_bias_WL = jax.vmap( lambda i: other_representation_to_body( C_v̇_WL=C_v̇_WL[i], C_v_WC=C_v_WC[i], L_H_C=L_H_C[i], L_v_LC=L_v_LC[i], ) )(jnp.arange(model.number_of_links())) case _: raise ValueError(data.velocity_representation) # Compute the bias of the 6D momentum derivative. def bias_momentum_derivative_term( link_index: jtp.Int, L_a_bias_WL: jtp.Vector ) -> jtp.Vector: # Get the body-fixed 6D inertia matrix. L_M_L = js.link.spatial_inertia(model=model, link_index=link_index) # Compute the body-fixed 6D velocity. L_v_WL = js.link.velocity( model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body ) # Compute the world-to-link transformations for 6D forces. W_Xf_L = jaxsim.math.Adjoint.from_transform( transform=W_H_L[link_index], inverse=True ).T # Compute the contribution of the link to the bias acceleration of the CoM. W_ḣ_bias_link_contribution = W_Xf_L @ ( L_M_L @ L_a_bias_WL + jaxsim.math.Cross.vx_star(L_v_WL) @ L_M_L @ L_v_WL ) return W_ḣ_bias_link_contribution # Sum the contributions of all links to the bias acceleration of the CoM. W_ḣ_bias = jax.vmap(bias_momentum_derivative_term)( jnp.arange(model.number_of_links()), L_a_bias_WL ).sum(axis=0) # Compute the total mass of the model. m = js.model.total_mass(model=model) # Compute the position of the CoM. 
W_p_CoM = com_position(model=model, data=data) match data.velocity_representation: # G := G[W] = (W_p_CoM, [W]) case VelRepr.Inertial | VelRepr.Mixed: W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) GW_Xf_W = jaxsim.math.Adjoint.from_transform(W_H_GW).T GW_ḣ_bias = GW_Xf_W @ W_ḣ_bias GW_v̇l_com_bias = GW_ḣ_bias[0:3] / m return GW_v̇l_com_bias # G := G[B] = (W_p_CoM, [B]) case VelRepr.Body: GB_Xf_W = jaxsim.math.Adjoint.from_transform( transform=data._base_transform.at[0:3].set(W_p_CoM) ).T GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias GB_v̇l_com_bias = GB_ḣ_bias[0:3] / m return GB_v̇l_com_bias case _: raise ValueError(data.velocity_representation) ================================================ FILE: src/jaxsim/api/common.py ================================================ import abc import contextlib import dataclasses import enum import functools from collections.abc import Callable, Iterator from typing import ParamSpec, TypeVar import jax import jax.numpy as jnp import jax_dataclasses from jax_dataclasses import Static import jaxsim.typing as jtp from jaxsim.math import Adjoint from jaxsim.utils import JaxsimDataclass, Mutability try: from typing import Self except ImportError: from typing_extensions import Self _P = ParamSpec("_P") _R = TypeVar("_R") def named_scope(fn, name: str | None = None) -> Callable[_P, _R]: """Apply a JAX named scope to a function for improved profiling and clarity.""" @functools.wraps(fn) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: with jax.named_scope(name or fn.__name__): return fn(*args, **kwargs) return wrapper @enum.unique class VelRepr(enum.IntEnum): """ Enumeration of all supported 6D velocity representations. """ Body = enum.auto() Mixed = enum.auto() Inertial = enum.auto() @jax_dataclasses.pytree_dataclass class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC): """ Base class for model data structures with velocity representation. 
""" velocity_representation: Static[VelRepr] = dataclasses.field( default=VelRepr.Inertial, kw_only=True ) @contextlib.contextmanager def switch_velocity_representation( self, velocity_representation: VelRepr ) -> Iterator[Self]: """ Context manager to temporarily switch the velocity representation. Args: velocity_representation: The new velocity representation. Yields: The same object with the new velocity representation. """ original_representation = self.velocity_representation try: # First, we replace the velocity representation. with self.mutable_context( mutability=Mutability.MUTABLE_NO_VALIDATION, restore_after_exception=True, ): self.velocity_representation = velocity_representation # Then, we yield the data with changed representation. # We run this in a mutable context with restoration so that any exception # occurring, we restore the original object in case it was modified. with self.mutable_context( mutability=self.mutability(), restore_after_exception=True ): yield self finally: with self.mutable_context( mutability=Mutability.MUTABLE_NO_VALIDATION, restore_after_exception=True, ): self.velocity_representation = original_representation @staticmethod @functools.partial(jax.jit, static_argnames=["other_representation", "is_force"]) def inertial_to_other_representation( array: jtp.Array, other_representation: VelRepr, transform: jtp.Matrix, *, is_force: bool, ) -> jtp.Array: r""" Convert a 6D quantity from inertial-fixed to another representation. Args: array: The 6D quantity to convert. other_representation: The representation to convert to. transform: The :math:`W \mathbf{H}_O` transform, where :math:`O` is the reference frame of the other representation. is_force: Whether the quantity is a 6D force or a 6D velocity. Returns: The 6D quantity in the other representation. 
""" W_array = array W_H_O = transform match other_representation: case VelRepr.Inertial: return W_array case VelRepr.Body: if not is_force: O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True) O_array = jnp.einsum("...ij,...j->...i", O_Xv_W, W_array) else: O_Xf_W = Adjoint.from_transform(transform=W_H_O).swapaxes(-1, -2) O_array = jnp.einsum("...ij,...j->...i", O_Xf_W, W_array) return O_array case VelRepr.Mixed: W_H_OW = W_H_O.at[..., 0:3, 0:3].set(jnp.eye(3)) if not is_force: OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True) OW_array = jnp.einsum("...ij,...j->...i", OW_Xv_W, W_array) else: OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).swapaxes(-1, -2) OW_array = jnp.einsum("...ij,...j->...i", OW_Xf_W, W_array) return OW_array case _: raise ValueError(other_representation) @staticmethod @functools.partial(jax.jit, static_argnames=["other_representation", "is_force"]) def other_representation_to_inertial( array: jtp.Array, other_representation: VelRepr, transform: jtp.Matrix, *, is_force: bool, ) -> jtp.Array: r""" Convert a 6D quantity from another representation to inertial-fixed. Args: array: The 6D quantity to convert. other_representation: The representation to convert from. transform: The :math:`W \mathbf{H}_O` transform, where :math:`O` is the reference frame of the other representation. is_force: Whether the quantity is a 6D force or a 6D velocity. Returns: The 6D quantity in the inertial-fixed representation.
""" O_array = array W_H_O = transform match other_representation: case VelRepr.Inertial: return O_array case VelRepr.Body: if not is_force: W_Xv_O = Adjoint.from_transform(W_H_O) W_array = jnp.einsum("...ij,...j->...i", W_Xv_O, O_array) else: W_Xf_O = Adjoint.from_transform( transform=W_H_O, inverse=True ).swapaxes(-1, -2) W_array = jnp.einsum("...ij,...j->...i", W_Xf_O, O_array) return W_array case VelRepr.Mixed: W_H_OW = W_H_O.at[..., 0:3, 0:3].set(jnp.eye(3)) if not is_force: W_Xv_BW = Adjoint.from_transform(W_H_OW) W_array = jnp.einsum("...ij,...j->...i", W_Xv_BW, O_array) else: W_Xf_BW = Adjoint.from_transform( transform=W_H_OW, inverse=True ).swapaxes(-1, -2) W_array = jnp.einsum("...ij,...j->...i", W_Xf_BW, O_array) return W_array case _: raise ValueError(other_representation) ================================================ FILE: src/jaxsim/api/contact.py ================================================ from __future__ import annotations import functools import jax import jax.numpy as jnp import jaxsim.api as js import jaxsim.exceptions import jaxsim.typing as jtp from jaxsim import logging from jaxsim.math import Adjoint, Cross, Transform from jaxsim.rbda.contacts import SoftContacts from .common import VelRepr @jax.jit @js.common.named_scope def collidable_point_kinematics( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> tuple[jtp.Matrix, jtp.Matrix]: """ Compute the position and 3D velocity of the collidable points in the world frame. Args: model: The model to consider. data: The data of the considered model. Returns: The position and velocity of the collidable points in the world frame. Note: The collidable point velocity is the plain coordinate derivative of the position. If we attach a frame C = (p_C, [C]) to the collidable point, it corresponds to the linear component of the mixed 6D frame velocity. 
""" W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( model=model, link_transforms=data._link_transforms, link_velocities=data._link_velocities, ) return W_p_Ci, W_ṗ_Ci @jax.jit @js.common.named_scope def collidable_point_positions( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: """ Compute the position of the collidable points in the world frame. Args: model: The model to consider. data: The data of the considered model. Returns: The position of the collidable points in the world frame. """ W_p_Ci, _ = collidable_point_kinematics(model=model, data=data) return W_p_Ci @jax.jit @js.common.named_scope def collidable_point_velocities( model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: """ Compute the 3D velocity of the collidable points in the world frame. Args: model: The model to consider. data: The data of the considered model. Returns: The 3D velocity of the collidable points. """ _, W_ṗ_Ci = collidable_point_kinematics(model=model, data=data) return W_ṗ_Ci @functools.partial(jax.jit, static_argnames=["link_names"]) @js.common.named_scope def in_contact( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_names: tuple[str, ...] | None = None, ) -> jtp.Vector: """ Return whether the links are in contact with the terrain. Args: model: The model to consider. data: The data of the considered model. link_names: The names of the links to consider. If None, all links are considered. Returns: A boolean vector indicating whether the links are in contact with the terrain. """ if link_names is not None and set(link_names).difference(model.link_names()): raise ValueError("One or more link names are not part of the model") # Get the indices of the enabled collidable points. 
indices_of_enabled_collidable_points = ( model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points ) parent_link_idx_of_enabled_collidable_points = jnp.array( model.kin_dyn_parameters.contact_parameters.body, dtype=int )[indices_of_enabled_collidable_points] W_p_Ci = collidable_point_positions(model=model, data=data) terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))( W_p_Ci[:, 0], W_p_Ci[:, 1] ) below_terrain = W_p_Ci[:, 2] <= terrain_height link_idxs = ( js.link.names_to_idxs(link_names=link_names, model=model) if link_names is not None else jnp.arange(model.number_of_links()) ) links_in_contact = jax.vmap( lambda link_index: jnp.where( parent_link_idx_of_enabled_collidable_points == link_index, below_terrain, jnp.zeros_like(below_terrain, dtype=bool), ).any() )(link_idxs) return links_in_contact def estimate_good_soft_contacts_parameters( *args, **kwargs ) -> jaxsim.rbda.contacts.ContactParamsTypes: """ Estimate good soft contacts parameters. Deprecated, use `estimate_good_contact_parameters` instead. """ msg = "This method is deprecated, please use `{}`." logging.warning(msg.format(estimate_good_contact_parameters.__name__)) return estimate_good_contact_parameters(*args, **kwargs) def estimate_good_contact_parameters( model: js.model.JaxSimModel, *, standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, static_friction_coefficient: jtp.FloatLike = 0.5, number_of_active_collidable_points_steady_state: jtp.IntLike = 1, damping_ratio: jtp.FloatLike = 1.0, max_penetration: jtp.FloatLike | None = None, ) -> jaxsim.rbda.contacts.ContactParamsTypes: """ Estimate good contact parameters. Args: model: The model to consider. standard_gravity: The standard gravity acceleration. static_friction_coefficient: The static friction coefficient. number_of_active_collidable_points_steady_state: The number of active collidable points in steady state. damping_ratio: The damping ratio. 
max_penetration: The maximum penetration allowed. Returns: The estimated good contacts parameters. Note: This is primarily a convenience function for soft-like contact models. However, it provides with some good default parameters also for the other ones. Note: This method provides a good set of contacts parameters. The user is encouraged to fine-tune the parameters based on the specific application. """ if max_penetration is None: zero_data = js.data.JaxSimModelData.build(model=model) W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2] if model.floating_base(): W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1] W_pz_CoM = W_pz_CoM - W_pz_C.min() # Consider as default a 1% of the model center of mass height. max_penetration = 0.01 * W_pz_CoM nc = number_of_active_collidable_points_steady_state return model.contact_model._parameters_class().build_default_from_jaxsim_model( model=model, standard_gravity=standard_gravity, static_friction_coefficient=static_friction_coefficient, max_penetration=max_penetration, number_of_active_collidable_points_steady_state=nc, damping_ratio=damping_ratio, ) @jax.jit @js.common.named_scope def transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array: r""" Return the pose of the enabled collidable points. Args: model: The model to consider. data: The data of the considered model. Returns: The stacked SE(3) matrices of all enabled collidable points. Note: Each collidable point is implicitly associated with a frame :math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the collidable point and :math:`[L]` is the orientation frame of the link it is rigidly attached to. """ # Get the indices of the enabled collidable points. 
indices_of_enabled_collidable_points = ( model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points ) parent_link_idx_of_enabled_collidable_points = jnp.array( model.kin_dyn_parameters.contact_parameters.body, dtype=int )[indices_of_enabled_collidable_points] # Get the transforms of the parent link of all collidable points. W_H_L = data._link_transforms[parent_link_idx_of_enabled_collidable_points] L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ indices_of_enabled_collidable_points ] # Build the link-to-point transform from the displacement between the link frame L # and the implicit contact frame C. L_H_C = jax.vmap(jnp.eye(4).at[0:3, 3].set)(L_p_Ci) # Compose the world-to-link and link-to-point transforms. return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C) @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) @js.common.named_scope def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, output_vel_repr: VelRepr | None = None, ) -> jtp.Array: r""" Return the free-floating Jacobian of the enabled collidable points. Args: model: The model to consider. data: The data of the considered model. output_vel_repr: The output velocity representation of the free-floating jacobian. Returns: The stacked :math:`6 \times (6+n)` free-floating jacobians of the frames associated to the enabled collidable points. Note: Each collidable point is implicitly associated with a frame :math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the collidable point and :math:`[L]` is the orientation frame of the link it is rigidly attached to. """ output_vel_repr = ( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) # Get the indices of the enabled collidable points.
indices_of_enabled_collidable_points = ( model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points ) parent_link_idx_of_enabled_collidable_points = jnp.array( model.kin_dyn_parameters.contact_parameters.body, dtype=int )[indices_of_enabled_collidable_points] # Compute the Jacobians of all links. W_J_WL = js.model.generalized_free_floating_jacobian( model=model, data=data, output_vel_repr=VelRepr.Inertial ) # Compute the contact Jacobian. # In inertial-fixed output representation, the Jacobian of the parent link is also # the Jacobian of the frame C implicitly associated with the collidable point. W_J_WC = W_J_WL[parent_link_idx_of_enabled_collidable_points] # Adjust the output representation. match output_vel_repr: case VelRepr.Inertial: O_J_WC = W_J_WC case VelRepr.Body: W_H_C = transforms(model=model, data=data) def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: C_X_W = jaxsim.math.Adjoint.from_transform( transform=W_H_C, inverse=True ) C_J_WC = C_X_W @ W_J_WC return C_J_WC O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC) case VelRepr.Mixed: W_H_C = transforms(model=model, data=data) def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) CW_X_W = jaxsim.math.Adjoint.from_transform( transform=W_H_CW, inverse=True ) CW_J_WC = CW_X_W @ W_J_WC return CW_J_WC O_J_WC = jax.vmap(mixed_jacobian)(W_H_C, W_J_WC) case _: raise ValueError(output_vel_repr) return O_J_WC @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) @js.common.named_scope def jacobian_derivative( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: r""" Compute the derivative of the free-floating jacobian of the enabled collidable points. Args: model: The model to consider. data: The data of the considered model. output_vel_repr: The output velocity representation of the free-floating jacobian derivative. 
Returns: The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the enabled collidable points. Note: The input representation of the free-floating jacobian derivative is the active velocity representation. """ output_vel_repr = ( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) indices_of_enabled_collidable_points = ( model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points ) # Get the index of the parent link and the position of the collidable point. parent_link_idx_of_enabled_collidable_points = jnp.array( model.kin_dyn_parameters.contact_parameters.body, dtype=int )[indices_of_enabled_collidable_points] L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ indices_of_enabled_collidable_points ] # Get the transforms of all the parent links. W_H_Li = data._link_transforms # Get the link velocities. W_v_WLi = data._link_velocities # ===================================================== # Compute quantities to adjust the input representation # ===================================================== def compute_T(model: js.model.JaxSimModel, X: jtp.Matrix) -> jtp.Matrix: In = jnp.eye(model.dofs()) T = jax.scipy.linalg.block_diag(X, In) return T def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: On = jnp.zeros(shape=(model.dofs(), model.dofs())) Ṫ = jax.scipy.linalg.block_diag(Ẋ, On) return Ṫ # Compute the operator to change the representation of ν, and its # time derivative. 
match data.velocity_representation: case VelRepr.Inertial: W_H_W = jnp.eye(4) W_X_W = Adjoint.from_transform(transform=W_H_W) W_Ẋ_W = jnp.zeros((6, 6)) T = compute_T(model=model, X=W_X_W) Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W) case VelRepr.Body: W_H_B = data._base_transform W_X_B = Adjoint.from_transform(transform=W_H_B) B_v_WB = data.base_velocity B_vx_WB = Cross.vx(B_v_WB) W_Ẋ_B = W_X_B @ B_vx_WB T = compute_T(model=model, X=W_X_B) Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) case VelRepr.Mixed: W_H_B = data._base_transform W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) W_X_BW = Adjoint.from_transform(transform=W_H_BW) BW_v_WB = data.base_velocity BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) BW_vx_W_BW = Cross.vx(BW_v_W_BW) W_Ẋ_BW = W_X_BW @ BW_vx_W_BW T = compute_T(model=model, X=W_X_BW) Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW) case _: raise ValueError(data.velocity_representation) # ===================================================== # Compute quantities to adjust the output representation # ===================================================== with data.switch_velocity_representation(VelRepr.Inertial): # Compute the Jacobian of the parent link in inertial representation. W_J_WL_W = js.model.generalized_free_floating_jacobian( model=model, data=data, ) # Compute the Jacobian derivative of the parent link in inertial representation. 
W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative( model=model, data=data, ) def compute_O_J̇_WC_I( L_p_C: jtp.Vector, parent_link_idx: jtp.Int, W_H_L: jtp.Matrix, ) -> jtp.Matrix: match output_vel_repr: case VelRepr.Inertial: O_X_W = W_X_W = Adjoint.from_transform( # noqa: F841 transform=jnp.eye(4) ) O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6)) # noqa: F841 case VelRepr.Body: L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) W_H_C = W_H_L[parent_link_idx] @ L_H_C O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) W_v_WC = W_v_WLi[parent_link_idx] W_vx_WC = Cross.vx(W_v_WC) O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC # noqa: F841 case VelRepr.Mixed: L_H_C = Transform.from_rotation_and_translation(translation=L_p_C) W_H_C = W_H_L[parent_link_idx] @ L_H_C W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) CW_H_W = Transform.inverse(W_H_CW) O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W) CW_v_WC = CW_X_W @ W_v_WLi[parent_link_idx] W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3]) W_vx_W_CW = Cross.vx(W_v_W_CW) O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW # noqa: F841 case _: raise ValueError(output_vel_repr) O_J̇_WC_I = jnp.zeros(shape=(6, 6 + model.dofs())) O_J̇_WC_I += O_Ẋ_W @ W_J_WL_W[parent_link_idx] @ T O_J̇_WC_I += O_X_W @ W_J̇_WL_W[parent_link_idx] @ T O_J̇_WC_I += O_X_W @ W_J_WL_W[parent_link_idx] @ Ṫ return O_J̇_WC_I O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, None))( L_p_Ci, parent_link_idx_of_enabled_collidable_points, W_H_Li ) return O_J̇_WC @jax.jit @js.common.named_scope def link_contact_forces( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_forces: jtp.MatrixLike | None = None, joint_torques: jtp.VectorLike | None = None, ) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]: """ Compute the 6D contact forces of all links of the model in inertial representation. Args: model: The model to consider. data: The data of the considered model. 
link_forces: The 6D external forces to apply to the links expressed in inertial representation joint_torques: The joint torques acting on the joints. Returns: A `(nL, 6)` array containing the stacked 6D contact forces of the links, expressed in inertial representation. """ # Compute the contact forces for each collidable point with the active contact model. W_f_C, aux_dict = model.contact_model.compute_contact_forces( model=model, data=data, **( dict(link_forces=link_forces, joint_force_references=joint_torques) if not isinstance(model.contact_model, SoftContacts) else {} ), ) # Compute the 6D forces applied to the links equivalent to the forces applied # to the frames associated to the collidable points. W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C) return W_f_L, aux_dict def link_forces_from_contact_forces( model: js.model.JaxSimModel, *, contact_forces: jtp.MatrixLike, ) -> jtp.Matrix: """ Compute the link forces from the contact forces. Args: model: The robot model considered by the contact model. contact_forces: The contact forces computed by the contact model. Returns: The 6D contact forces applied to the links and expressed in the frame of the velocity representation of data. """ # Get the object storing the contact parameters of the model. contact_parameters = model.kin_dyn_parameters.contact_parameters # Extract the indices corresponding to the enabled collidable points. indices_of_enabled_collidable_points = ( contact_parameters.indices_of_enabled_collidable_points ) # Convert the contact forces to a JAX array. W_f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze()) # Construct the vector defining the parent link index of each collidable point. # We use this vector to sum the 6D forces of all collidable points rigidly # attached to the same link. 
parent_link_index_of_collidable_points = jnp.array( contact_parameters.body, dtype=int )[indices_of_enabled_collidable_points] # Create the mask that associate each collidable point to their parent link. # We use this mask to sum the collidable points to the right link. mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange( model.number_of_links() ) # Sum the forces of all collidable points rigidly attached to a body. # Since the contact forces W_f_C are expressed in the world frame, # we don't need any coordinate transformation. W_f_L = mask.T @ W_f_C return W_f_L ================================================ FILE: src/jaxsim/api/data.py ================================================ from __future__ import annotations import dataclasses import functools from collections.abc import Sequence try: from typing import Self, override except ImportError: from typing_extensions import override, Self import jax import jax.numpy as jnp import jax.scipy.spatial.transform import jax_dataclasses import jaxsim.api as js import jaxsim.math import jaxsim.rbda import jaxsim.typing as jtp from . import common from .common import VelRepr @jax_dataclasses.pytree_dataclass class JaxSimModelData(common.ModelDataWithVelocityRepresentation): """ Class storing the state of the physics model dynamics. Attributes: joint_positions: The vector of joint positions. joint_velocities: The vector of joint velocities. base_position: The 3D position of the base link. base_quaternion: The quaternion defining the orientation of the base link. base_linear_velocity: The linear velocity of the base link in inertial-fixed representation. base_angular_velocity: The angular velocity of the base link in inertial-fixed representation. base_transform: The base transform. joint_transforms: The joint transforms. link_transforms: The link transforms. link_velocities: The link velocities in inertial-fixed representation. 
""" # Joint state _joint_positions: jtp.Vector _joint_velocities: jtp.Vector # Base state _base_quaternion: jtp.Vector _base_linear_velocity: jtp.Vector _base_angular_velocity: jtp.Vector _base_position: jtp.Vector # Cached computations. _base_transform: jtp.Matrix = dataclasses.field(repr=False, default=None) _joint_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None) _link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None) _link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None) # Extended state for soft and rigid contact models. contact_state: dict[str, jtp.Array] = dataclasses.field(default_factory=dict) @staticmethod def build( model: js.model.JaxSimModel, base_position: jtp.VectorLike | None = None, base_quaternion: jtp.VectorLike | None = None, joint_positions: jtp.VectorLike | None = None, base_linear_velocity: jtp.VectorLike | None = None, base_angular_velocity: jtp.VectorLike | None = None, joint_velocities: jtp.VectorLike | None = None, contact_state: dict[str, jtp.Array] | None = None, velocity_representation: VelRepr = VelRepr.Mixed, ) -> JaxSimModelData: """ Create a `JaxSimModelData` object with the given state. Args: model: The model for which to create the state. base_position: The base position. base_quaternion: The base orientation as a quaternion. joint_positions: The joint positions. base_linear_velocity: The base linear velocity in the selected representation. base_angular_velocity: The base angular velocity in the selected representation. joint_velocities: The joint velocities. velocity_representation: The velocity representation to use. It defaults to mixed if not provided. contact_state: The optional contact state. Returns: A `JaxSimModelData` initialized with the given state. 
""" base_position = jnp.array( base_position if base_position is not None else jnp.zeros(3), dtype=float, ).squeeze() base_quaternion = jnp.array( ( base_quaternion if base_quaternion is not None else jnp.array([1.0, 0, 0, 0]) ), dtype=float, ).squeeze() base_linear_velocity = jnp.array( base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3), dtype=float, ).squeeze() base_angular_velocity = jnp.array( ( base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3) ), dtype=float, ).squeeze() joint_positions = jnp.atleast_1d( jnp.array( ( joint_positions if joint_positions is not None else jnp.zeros(model.dofs()) ), dtype=float, ).squeeze() ) joint_velocities = jnp.atleast_1d( jnp.array( ( joint_velocities if joint_velocities is not None else jnp.zeros(model.dofs()) ), dtype=float, ).squeeze() ) W_H_B = jaxsim.math.Transform.from_quaternion_and_translation( translation=base_position, quaternion=base_quaternion ) W_v_WB = JaxSimModelData.other_representation_to_inertial( array=jnp.hstack([base_linear_velocity, base_angular_velocity]), other_representation=velocity_representation, transform=W_H_B, is_force=False, ).astype(float) joint_transforms = model.kin_dyn_parameters.joint_transforms( joint_positions=joint_positions, base_transform=W_H_B ) link_transforms, link_velocities_inertial = ( jaxsim.rbda.forward_kinematics_model( model=model, base_position=base_position, base_quaternion=base_quaternion, joint_positions=joint_positions, base_linear_velocity_inertial=W_v_WB[0:3], base_angular_velocity_inertial=W_v_WB[3:6], joint_velocities=joint_velocities, joint_transforms=joint_transforms, ) ) contact_state = contact_state or {} if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts): contact_state["tangential_deformation"] = contact_state.get( "tangential_deformation", jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point), ) model_data = JaxSimModelData( velocity_representation=velocity_representation, 
_base_quaternion=base_quaternion, _base_position=base_position, _joint_positions=joint_positions, _base_linear_velocity=W_v_WB[0:3], _base_angular_velocity=W_v_WB[3:6], _joint_velocities=joint_velocities, _base_transform=W_H_B, _joint_transforms=joint_transforms, _link_transforms=link_transforms, _link_velocities=link_velocities_inertial, contact_state=contact_state, ) if not model_data.valid(model=model): raise ValueError( "The built state is not compatible with the model.", model_data ) return model_data @staticmethod def zero( model: js.model.JaxSimModel, velocity_representation: VelRepr = VelRepr.Mixed, ) -> JaxSimModelData: """ Create a `JaxSimModelData` object with zero state. Args: model: The model for which to create the state. velocity_representation: The velocity representation to use. It defaults to mixed if not provided. Returns: A `JaxSimModelData` initialized with zero state. """ return JaxSimModelData.build( model=model, velocity_representation=velocity_representation ) # ================== # Extract quantities # ================== @property def joint_positions(self) -> jtp.Vector: """ Get the joint positions. Returns: The joint positions. """ return self._joint_positions @property def joint_velocities(self) -> jtp.Vector: """ Get the joint velocities. Returns: The joint velocities. """ return self._joint_velocities @property def base_quaternion(self) -> jtp.Vector: """ Get the base quaternion. Returns: The base quaternion. """ return self._base_quaternion @property def base_position(self) -> jtp.Vector: """ Get the base position. Returns: The base position. """ return self._base_position @property def base_orientation(self) -> jtp.Matrix: """ Get the base orientation. Returns: The base orientation. """ # Extract the base quaternion. W_Q_B = self.base_quaternion # Always normalize the quaternion to avoid numerical issues. 
# If the active scheme does not integrate the quaternion on its manifold, # we introduce a Baumgarte stabilization to let the quaternion converge to # a unit quaternion. In this case, it is not guaranteed that the quaternion # stored in the state is a unit quaternion. norm = jaxsim.math.safe_norm(W_Q_B, axis=-1, keepdims=True) W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0)) return W_Q_B @property def base_velocity(self) -> jtp.Vector: """ Get the base 6D velocity. Returns: The base 6D velocity in the active representation. """ W_v_WB = jnp.concatenate( [self._base_linear_velocity, self._base_angular_velocity], axis=-1 ) W_H_B = self._base_transform return ( JaxSimModelData.inertial_to_other_representation( array=W_v_WB, other_representation=self.velocity_representation, transform=W_H_B, is_force=False, ) .squeeze() .astype(float) ) @property def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]: r""" Get the generalized position :math:`\mathbf{q} = ({}^W \mathbf{H}_B, \mathbf{s}) \in \text{SO}(3) \times \mathbb{R}^n`. Returns: A tuple containing the base transform and the joint positions. """ return self._base_transform, self.joint_positions @property def generalized_velocity(self) -> jtp.Vector: r""" Get the generalized velocity. :math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \mathbf{s}) \in \mathbb{R}^{6+n}` Returns: The generalized velocity in the active representation. """ return ( jnp.hstack([self.base_velocity, self.joint_velocities]) .squeeze() .astype(float) ) @property def base_transform(self) -> jtp.Matrix: """ Get the base transform. Returns: The base transform. """ return self._base_transform # ================ # Store quantities # ================ @js.common.named_scope @jax.jit def reset_base_quaternion( self, model: js.model.JaxSimModel, base_quaternion: jtp.VectorLike ) -> Self: """ Reset the base quaternion. Args: model: The JaxSim model to use. 
base_quaternion: The base orientation as a quaternion. Returns: The updated `JaxSimModelData` object. """ W_Q_B = jnp.array(base_quaternion, dtype=float) norm = jaxsim.math.safe_norm(W_Q_B, axis=-1) W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0)) return self.replace(model=model, base_quaternion=W_Q_B) @js.common.named_scope @jax.jit def reset_base_pose( self, model: js.model.JaxSimModel, base_pose: jtp.MatrixLike ) -> Self: """ Reset the base pose. Args: model: The JaxSim model to use. base_pose: The base pose as an SE(3) matrix. Returns: The updated `JaxSimModelData` object. """ base_pose = jnp.array(base_pose) W_p_B = base_pose[0:3, 3] W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3]) return self.replace( model=model, base_position=W_p_B, base_quaternion=W_Q_B, ) @override def replace( self, model: js.model.JaxSimModel, joint_positions: jtp.Vector | None = None, joint_velocities: jtp.Vector | None = None, base_quaternion: jtp.Vector | None = None, base_linear_velocity: jtp.Vector | None = None, base_angular_velocity: jtp.Vector | None = None, base_position: jtp.Vector | None = None, *, contact_state: dict[str, jtp.Array] | None = None, validate: bool = False, ) -> Self: """ Replace the attributes of the `JaxSimModelData` object. """ if joint_positions is None: joint_positions = self.joint_positions if joint_velocities is None: joint_velocities = self.joint_velocities if base_quaternion is None: base_quaternion = self.base_quaternion if base_position is None: base_position = self.base_position if contact_state is None: contact_state = self.contact_state # Normalize the quaternion to avoid numerical issues. 
base_quaternion_norm = jaxsim.math.safe_norm( base_quaternion, axis=-1, keepdims=True ) base_quaternion = base_quaternion / jnp.where( base_quaternion_norm == 0, 1.0, base_quaternion_norm ) joint_positions = jnp.atleast_2d(joint_positions.squeeze()).astype(float) joint_velocities = jnp.atleast_2d(joint_velocities.squeeze()).astype(float) base_quaternion = jnp.atleast_2d(base_quaternion.squeeze()).astype(float) base_position = jnp.atleast_2d(base_position.squeeze()).astype(float) base_transform = jaxsim.math.Transform.from_quaternion_and_translation( translation=base_position, quaternion=base_quaternion ).reshape((-1, 4, 4)) joint_transforms = jax.vmap(model.kin_dyn_parameters.joint_transforms)( joint_positions=joint_positions, base_transform=base_transform, ) if base_linear_velocity is None and base_angular_velocity is None: base_linear_velocity_inertial = self._base_linear_velocity base_angular_velocity_inertial = self._base_angular_velocity else: if base_linear_velocity is None: base_linear_velocity = self.base_velocity[:3] if base_angular_velocity is None: base_angular_velocity = self.base_velocity[3:] base_linear_velocity = jnp.atleast_1d(base_linear_velocity.squeeze()) base_angular_velocity = jnp.atleast_1d(base_angular_velocity.squeeze()) W_v_WB = JaxSimModelData.other_representation_to_inertial( array=jnp.hstack([base_linear_velocity, base_angular_velocity]), other_representation=self.velocity_representation, transform=base_transform, is_force=False, ).astype(float) base_linear_velocity_inertial, base_angular_velocity_inertial = ( W_v_WB[..., :3], W_v_WB[..., 3:], ) link_transforms, link_velocities = jax.vmap( jaxsim.rbda.forward_kinematics_model, in_axes=(None,) )( model, base_position=base_position, base_quaternion=base_quaternion, joint_positions=joint_positions, joint_velocities=joint_velocities, base_linear_velocity_inertial=jnp.atleast_2d(base_linear_velocity_inertial), base_angular_velocity_inertial=jnp.atleast_2d( base_angular_velocity_inertial ), 
joint_transforms=joint_transforms, ) # Adjust the output shapes. joint_positions = joint_positions.reshape(self._joint_positions.shape) joint_velocities = joint_velocities.reshape(self._joint_velocities.shape) base_quaternion = base_quaternion.reshape(self._base_quaternion.shape) base_linear_velocity_inertial = base_linear_velocity_inertial.reshape( self._base_linear_velocity.shape ) base_angular_velocity_inertial = base_angular_velocity_inertial.reshape( self._base_angular_velocity.shape ) base_position = base_position.reshape(self._base_position.shape) base_transform = base_transform.reshape(self._base_transform.shape) joint_transforms = joint_transforms.reshape(self._joint_transforms.shape) link_transforms = link_transforms.reshape(self._link_transforms.shape) link_velocities = link_velocities.reshape(self._link_velocities.shape) return super().replace( _joint_positions=joint_positions, _joint_velocities=joint_velocities, _base_quaternion=base_quaternion, _base_linear_velocity=base_linear_velocity_inertial, _base_angular_velocity=base_angular_velocity_inertial, _base_position=base_position, _base_transform=base_transform, _joint_transforms=joint_transforms, _link_transforms=link_transforms, _link_velocities=link_velocities, contact_state=contact_state, validate=validate, ) def valid(self, model: js.model.JaxSimModel) -> bool: """ Check if the `JaxSimModelData` is valid for a given `JaxSimModel`. Args: model: The `JaxSimModel` to validate the `JaxSimModelData` against. Returns: `True` if the `JaxSimModelData` is valid for the given model, `False` otherwise. 
""" if self._joint_positions.shape != (model.dofs(),): return False if self._joint_velocities.shape != (model.dofs(),): return False if self._base_position.shape != (3,): return False if self._base_quaternion.shape != (4,): return False if self._base_linear_velocity.shape != (3,): return False if self._base_angular_velocity.shape != (3,): return False return True @functools.partial(jax.jit, static_argnames=["velocity_representation", "base_rpy_seq"]) def random_model_data( model: js.model.JaxSimModel, *, key: jax.Array | None = None, velocity_representation: VelRepr | None = None, base_pos_bounds: tuple[ jtp.FloatLike | Sequence[jtp.FloatLike], jtp.FloatLike | Sequence[jtp.FloatLike], ] = ((-1, -1, 0.5), 1.0), base_rpy_bounds: tuple[ jtp.FloatLike | Sequence[jtp.FloatLike], jtp.FloatLike | Sequence[jtp.FloatLike], ] = (-jnp.pi, jnp.pi), base_rpy_seq: str = "XYZ", joint_pos_bounds: ( tuple[ jtp.FloatLike | Sequence[jtp.FloatLike], jtp.FloatLike | Sequence[jtp.FloatLike], ] | None ) = None, base_vel_lin_bounds: tuple[ jtp.FloatLike | Sequence[jtp.FloatLike], jtp.FloatLike | Sequence[jtp.FloatLike], ] = (-1.0, 1.0), base_vel_ang_bounds: tuple[ jtp.FloatLike | Sequence[jtp.FloatLike], jtp.FloatLike | Sequence[jtp.FloatLike], ] = (-1.0, 1.0), joint_vel_bounds: tuple[ jtp.FloatLike | Sequence[jtp.FloatLike], jtp.FloatLike | Sequence[jtp.FloatLike], ] = (-1.0, 1.0), ) -> JaxSimModelData: """ Randomly generate a `JaxSimModelData` object. Args: model: The target model for the random data. key: The random key. velocity_representation: The velocity representation to use. base_pos_bounds: The bounds for the base position. base_rpy_bounds: The bounds for the euler angles used to build the base orientation. base_rpy_seq: The sequence of axes for rotation (using `Rotation` from scipy). joint_pos_bounds: The bounds for the joint positions (reading the joint limits if None). base_vel_lin_bounds: The bounds for the base linear velocity. 
base_vel_ang_bounds: The bounds for the base angular velocity. joint_vel_bounds: The bounds for the joint velocities. Returns: A `JaxSimModelData` object with random data. """ key = key if key is not None else jax.random.PRNGKey(seed=0) k1, k2, k3, k4, k5, k6 = jax.random.split(key, num=6) p_min = jnp.array(base_pos_bounds[0], dtype=float) p_max = jnp.array(base_pos_bounds[1], dtype=float) rpy_min = jnp.array(base_rpy_bounds[0], dtype=float) rpy_max = jnp.array(base_rpy_bounds[1], dtype=float) v_min = jnp.array(base_vel_lin_bounds[0], dtype=float) v_max = jnp.array(base_vel_lin_bounds[1], dtype=float) ω_min = jnp.array(base_vel_ang_bounds[0], dtype=float) ω_max = jnp.array(base_vel_ang_bounds[1], dtype=float) ṡ_min, ṡ_max = joint_vel_bounds base_position = jax.random.uniform(key=k1, shape=(3,), minval=p_min, maxval=p_max) base_quaternion = jaxsim.math.Quaternion.to_wxyz( xyzw=jax.scipy.spatial.transform.Rotation.from_euler( seq=base_rpy_seq, angles=jax.random.uniform( key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max ), ).as_quat() ) ( joint_positions, joint_velocities, base_linear_velocity, base_angular_velocity, ) = (None,) * 4 if model.number_of_joints() > 0: s_min, s_max = ( jnp.array(joint_pos_bounds, dtype=float) if joint_pos_bounds is not None else (None, None) ) joint_positions = ( js.joint.random_joint_positions(model=model, key=k3) if (s_min is None or s_max is None) else jax.random.uniform( key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max ) ) joint_velocities = jax.random.uniform( key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max ) if model.floating_base(): base_linear_velocity = jax.random.uniform( key=k5, shape=(3,), minval=v_min, maxval=v_max ) base_angular_velocity = jax.random.uniform( key=k6, shape=(3,), minval=ω_min, maxval=ω_max ) return JaxSimModelData.build( model=model, base_position=base_position, base_quaternion=base_quaternion, joint_positions=joint_positions, joint_velocities=joint_velocities, 
base_linear_velocity=base_linear_velocity, base_angular_velocity=base_angular_velocity, **( {"velocity_representation": velocity_representation} if velocity_representation is not None else {} ), ) ================================================ FILE: src/jaxsim/api/frame.py ================================================ import functools from collections.abc import Sequence import jax import jax.numpy as jnp import jaxsim.api as js import jaxsim.typing as jtp from jaxsim import exceptions from jaxsim.math import Adjoint, Cross from .common import VelRepr # ======================= # Index-related functions # ======================= @jax.jit @js.common.named_scope def idx_of_parent_link( model: js.model.JaxSimModel, *, frame_index: jtp.IntLike ) -> jtp.Int: """ Get the index of the link to which the frame is rigidly attached. Args: model: The model to consider. frame_index: The index of the frame. Returns: The index of the frame's parent link. """ n_l = model.number_of_links() n_f = len(model.frame_names()) exceptions.raise_value_error_if( condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(), msg="Invalid frame index '{idx}'", idx=frame_index, ) return jnp.array(model.kin_dyn_parameters.frame_parameters.body)[ frame_index - model.number_of_links() ] @functools.partial(jax.jit, static_argnames="frame_name") @js.common.named_scope def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int: """ Convert the name of a frame to its index. Args: model: The model to consider. frame_name: The name of the frame. Returns: The index of the frame. """ if frame_name not in model.frame_names(): raise ValueError(f"Frame '{frame_name}' not found in the model.") return ( jnp.array( model.number_of_links() + model.kin_dyn_parameters.frame_parameters.name.index(frame_name) ) .astype(int) .squeeze() ) def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str: """ Convert the index of a frame to its name. 
Args: model: The model to consider. frame_index: The index of the frame. Returns: The name of the frame. """ n_l = model.number_of_links() n_f = len(model.frame_names()) exceptions.raise_value_error_if( condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(), msg="Invalid frame index '{idx}'", idx=frame_index, ) return model.kin_dyn_parameters.frame_parameters.name[ frame_index - model.number_of_links() ] @functools.partial(jax.jit, static_argnames=["frame_names"]) @js.common.named_scope def names_to_idxs( model: js.model.JaxSimModel, *, frame_names: Sequence[str] ) -> jax.Array: """ Convert a sequence of frame names to their corresponding indices. Args: model: The model to consider. frame_names: The names of the frames. Returns: The indices of the frames. """ return jnp.array( [name_to_idx(model=model, frame_name=name) for name in frame_names] ).astype(int) def idxs_to_names( model: js.model.JaxSimModel, *, frame_indices: Sequence[jtp.IntLike] ) -> tuple[str, ...]: """ Convert a sequence of frame indices to their corresponding names. Args: model: The model to consider. frame_indices: The indices of the frames. Returns: The names of the frames. """ return tuple(idx_to_name(model=model, frame_index=idx) for idx in frame_indices) # ========== # Frame APIs # ========== @jax.jit @js.common.named_scope def transform( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, frame_index: jtp.IntLike, ) -> jtp.Matrix: """ Compute the SE(3) transform from the world frame to the specified frame. Args: model: The model to consider. data: The data of the considered model. frame_index: The index of the frame for which the transform is requested. Returns: The 4x4 matrix representing the transform. """ n_l = model.number_of_links() n_f = len(model.frame_names()) exceptions.raise_value_error_if( condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(), msg="Invalid frame index '{idx}'", idx=frame_index, ) # Compute the necessary transforms. 
L = idx_of_parent_link(model=model, frame_index=frame_index) W_H_L = js.link.transform(model=model, data=data, link_index=L) # Get the static frame pose wrt the parent link. L_H_F = model.kin_dyn_parameters.frame_parameters.transform[ frame_index - model.number_of_links() ] # Combine the transforms computing the frame pose. return W_H_L @ L_H_F @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) @js.common.named_scope def velocity( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, frame_index: jtp.IntLike, output_vel_repr: VelRepr | None = None, ) -> jtp.Vector: """ Compute the 6D velocity of the frame. Args: model: The model to consider. data: The data of the considered model. frame_index: The index of the frame. output_vel_repr: The output velocity representation of the frame velocity. Returns: The 6D velocity of the frame in the specified velocity representation. """ n_l = model.number_of_links() n_f = model.number_of_frames() exceptions.raise_value_error_if( condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(), msg="Invalid frame index '{idx}'", idx=frame_index, ) output_vel_repr = ( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) # Get the frame jacobian having I as input representation (taken from data) # and O as output representation, specified by the user (or taken from data). O_J_WF_I = jacobian( model=model, data=data, frame_index=frame_index, output_vel_repr=output_vel_repr, ) # Get the generalized velocity in the input velocity representation. I_ν = data.generalized_velocity # Compute the frame velocity in the output velocity representation. return O_J_WF_I @ I_ν @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) @js.common.named_scope def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, frame_index: jtp.IntLike, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: r""" Compute the free-floating jacobian of the frame. 
Args: model: The model to consider. data: The data of the considered model. frame_index: The index of the frame. output_vel_repr: The output velocity representation of the free-floating jacobian. Returns: The :math:`6 \times (6+n)` free-floating jacobian of the frame. Note: The input representation of the free-floating jacobian is the active velocity representation. """ n_l = model.number_of_links() n_f = model.number_of_frames() exceptions.raise_value_error_if( condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(), msg="Invalid frame index '{idx}'", idx=frame_index, ) output_vel_repr = ( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) # Get the index of the parent link. L = idx_of_parent_link(model=model, frame_index=frame_index) # Compute only the parent-link body Jacobian. L_J_WL = js.link.jacobian( model=model, data=data, link_index=L, output_vel_repr=VelRepr.Body, ) W_H_L = data._link_transforms[L] L_H_F = model.kin_dyn_parameters.frame_parameters.transform[ frame_index - model.number_of_links() ] L_p_F = L_H_F[0:3, 3] # Adjust the output representation. 
match output_vel_repr:
    case VelRepr.Inertial:
        # Re-express the parent-link body Jacobian in the inertial frame W.
        W_X_L = Adjoint.from_rotation_and_translation(
            rotation=W_H_L[0:3, 0:3],
            translation=W_H_L[0:3, 3],
        )
        W_J_WL = W_X_L @ L_J_WL
        O_J_WL_I = W_J_WL

    case VelRepr.Body:
        # F_X_L is the inverse adjoint of the fixed link-to-frame transform L_H_F.
        F_X_L = Adjoint.from_rotation_and_translation(
            rotation=L_H_F[0:3, 0:3],
            translation=L_p_F,
            inverse=True,
        )
        F_J_WL = F_X_L @ L_J_WL
        O_J_WL_I = F_J_WL

    case VelRepr.Mixed:
        # Mixed representation: origin of F, orientation of W. The translation
        # from F to L expressed in W is -W_R_L @ L_p_F.
        W_R_L = W_H_L[0:3, 0:3]
        FW_X_L = Adjoint.from_rotation_and_translation(
            rotation=W_R_L,
            translation=-W_R_L @ L_p_F,
        )
        FW_J_WL = FW_X_L @ L_J_WL
        O_J_WL_I = FW_J_WL

    case _:
        raise ValueError(output_vel_repr)

return O_J_WL_I


@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
@js.common.named_scope
def jacobian_derivative(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    frame_index: jtp.IntLike,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    r"""
    Compute the derivative of the free-floating jacobian of the frame.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        frame_index: The index of the frame.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian
            derivative.

    Returns:
        The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the frame.

    Note:
        The input representation of the free-floating jacobian derivative is the active
        velocity representation.
    """

    n_l = model.number_of_links()
    n_f = len(model.frame_names())

    # Frame indices follow link indices, so the valid range is [n_l, n_l + n_f).
    exceptions.raise_value_error_if(
        condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),
        msg="Invalid frame index '{idx}'",
        idx=frame_index,
    )

    output_vel_repr = (
        output_vel_repr if output_vel_repr is not None else data.velocity_representation
    )

    # Get the index of the parent link.
    L = idx_of_parent_link(model=model, frame_index=frame_index)

    # Inertial-fixed Jacobian of the parent link and its time derivative.
    W_J_WL_I = js.link.jacobian(
        model=model,
        data=data,
        link_index=L,
        output_vel_repr=VelRepr.Inertial,
    )

    W_J̇_WL_I = js.link.jacobian_derivative(
        model=model,
        data=data,
        link_index=L,
        output_vel_repr=VelRepr.Inertial,
    )

    W_H_L = data._link_transforms[L]
    # Frame parameters are stored starting from index 0, hence the link-count offset.
    L_H_F = model.kin_dyn_parameters.frame_parameters.transform[
        frame_index - model.number_of_links()
    ]
    W_H_F = W_H_L @ L_H_F

    # =====================================================
    # Compute quantities to adjust the output representation
    # =====================================================

    # In the inertial-fixed representation, frames rigidly attached to the same
    # link share the same velocity, so the parent-link Jacobian yields W_v_WF.
    W_v_WF = W_J_WL_I @ data.generalized_velocity

    # O_X_W maps inertial-fixed quantities to the output representation O;
    # O_Ẋ_W is its time derivative, needed by the product rule below.
    match output_vel_repr:
        case VelRepr.Inertial:
            O_X_W = jnp.eye(6, dtype=W_H_F.dtype)
            O_Ẋ_W = jnp.zeros((6, 6), dtype=W_H_F.dtype)

        case VelRepr.Body:
            O_X_W = Adjoint.from_rotation_and_translation(
                rotation=W_H_F[0:3, 0:3],
                translation=W_H_F[0:3, 3],
                inverse=True,
            )
            O_Ẋ_W = -O_X_W @ Cross.vx(W_v_WF)

        case VelRepr.Mixed:
            # Mixed frame: origin of F, identity rotation w.r.t. W.
            O_X_W = Adjoint.from_rotation_and_translation(
                rotation=jnp.eye(3, dtype=W_H_F.dtype),
                translation=W_H_F[0:3, 3],
                inverse=True,
            )
            FW_v_WF = O_X_W @ W_v_WF
            # Keep only the linear part: the mixed frame does not rotate with F.
            W_v_W_FW = FW_v_WF.at[3:6].set(jnp.zeros_like(FW_v_WF[3:6]))
            O_Ẋ_W = -O_X_W @ Cross.vx(W_v_W_FW)

        case _:
            raise ValueError(output_vel_repr)

    # Product rule: d/dt (O_X_W @ W_J_WL) = O_Ẋ_W @ W_J_WL + O_X_W @ W_J̇_WL.
    O_J̇_WF_I = O_Ẋ_W @ W_J_WL_I
    O_J̇_WF_I += O_X_W @ W_J̇_WL_I

    return O_J̇_WF_I
================================================
FILE: src/jaxsim/api/integrators.py
================================================
import dataclasses
from collections.abc import Callable

import jax
import jax.numpy as jnp

import jaxsim
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.api.data import JaxSimModelData
from jaxsim.math import Skew


def semi_implicit_euler_integration(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    link_forces: jtp.Vector,
    joint_torques: jtp.Vector,
) -> JaxSimModelData:
    """Integrate the system state using the semi-implicit Euler method."""

    # The whole step is computed with inertial-fixed velocities.
    with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):

        # Compute the system acceleration
        W_v̇_WB, s̈, contact_state_derivative = js.ode.system_acceleration(
            model=model,
            data=data,
            link_forces=link_forces,
            joint_torques=joint_torques,
        )

        dt = model.time_step

        # Compute the new generalized velocity.
        new_generalized_acceleration = jnp.hstack([W_v̇_WB, s̈])
        new_generalized_velocity = (
            data.generalized_velocity + dt * new_generalized_acceleration
        )

        # Extract the new base and joint velocities.
        W_v_B = new_generalized_velocity[0:6]
        ṡ = new_generalized_velocity[6:]

        # Compute the new base position and orientation.
        W_ω_WB = new_generalized_velocity[3:6]

        # Semi-implicit step: positions are advanced with the *updated* velocity.
        # The linear part of the inertial-fixed base velocity is
        # W_v_lin = W_ṗ_B - W_ω_WB × W_p_B, therefore the position derivative is
        # recovered by adding the skew-symmetric matrix of the base angular
        # velocity times the base position: W_ṗ_B = W_v_lin + W_ω_WB × W_p_B.
        # See: S. Traversaro and A. Saccon, “Multibody Dynamics Notation (Version 2), pg.9
        W_ṗ_B = new_generalized_velocity[0:3] + Skew.wedge(W_ω_WB) @ data.base_position

        W_Q̇_B = jaxsim.math.Quaternion.derivative(
            quaternion=data.base_orientation,
            omega=W_ω_WB,
            omega_in_body_fixed=False,
        ).squeeze()

        W_p_B = data.base_position + dt * W_ṗ_B
        W_Q_B = data.base_orientation + dt * W_Q̇_B

        # Renormalize the quaternion, guarding against a zero norm.
        base_quaternion_norm = jaxsim.math.safe_norm(W_Q_B, axis=-1)

        W_Q_B = W_Q_B / jnp.where(base_quaternion_norm == 0, 1.0, base_quaternion_norm)

        s = data.joint_positions + dt * ṡ

        integrated_contact_state = jax.tree.map(
            lambda x, x_dot: x + dt * x_dot,
            data.contact_state,
            contact_state_derivative,
        )

    data = dataclasses.replace(
        data,
        _base_quaternion=W_Q_B,
        _base_position=W_p_B,
        _joint_positions=s,
        _joint_velocities=ṡ,
        _base_linear_velocity=W_v_B[0:3],
        _base_angular_velocity=W_ω_WB,
        contact_state=integrated_contact_state,
    )

    # Recompute kinematic caches for the new state.
    data = data.replace(model=model)

    return data


def rk4_integration(
    model: js.model.JaxSimModel,
    data: JaxSimModelData,
    link_forces: jtp.Vector,
    joint_torques: jtp.Vector,
) -> JaxSimModelData:
    """Integrate the system state using the Runge-Kutta 4 method."""

    dt = model.time_step

    # State derivative evaluated at an intermediate state x (dict of state fields).
    def f(x) -> dict[str, jtp.Matrix]:

        with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):

            data_ti = data.replace(model=model, **x)

            return js.ode.system_dynamics(
                model=model,
                data=data_ti,
                link_forces=link_forces,
                joint_torques=joint_torques,
            )

    # Start the step from a unit quaternion (guarding against a zero norm).
    base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion, axis=-1)
    base_quaternion = data._base_quaternion / jnp.where(
        base_quaternion_norm == 0, 1.0, base_quaternion_norm
    )

    x_t0 = dict(
        base_position=data._base_position,
        base_quaternion=base_quaternion,
        joint_positions=data._joint_positions,
        base_linear_velocity=data._base_linear_velocity,
        base_angular_velocity=data._base_angular_velocity,
        joint_velocities=data._joint_velocities,
        contact_state=data.contact_state,
    )

    euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt
    euler_fin = lambda x, dxdt: x + dt * dxdt

    # Classic four-stage evaluation.
    k1 = f(x_t0)
    k2 = f(jax.tree.map(euler_mid, x_t0, k1))
    k3 = f(jax.tree.map(euler_mid, x_t0, k2))
    k4 = f(jax.tree.map(euler_fin, x_t0, k3))

    # Average the slopes and compute the RK4 state derivative.
average = lambda k1, k2, k3, k4: (k1 + 2 * k2 + 2 * k3 + k4) / 6 dxdt = jax.tree.map(average, k1, k2, k3, k4) # Integrate the dynamics x_tf = jax.tree.map(euler_fin, x_t0, dxdt) data_tf = dataclasses.replace( data, _base_position=x_tf["base_position"], _base_quaternion=x_tf["base_quaternion"], _joint_positions=x_tf["joint_positions"], _base_linear_velocity=x_tf["base_linear_velocity"], _base_angular_velocity=x_tf["base_angular_velocity"], _joint_velocities=x_tf["joint_velocities"], contact_state=x_tf["contact_state"], ) return data_tf.replace(model=model) def rk4fast_integration( model: js.model.JaxSimModel, data: JaxSimModelData, link_forces: jtp.Vector, joint_torques: jtp.Vector, ) -> JaxSimModelData: """ Integrate the system state using the Runge-Kutta 4 fast method. Note: This method is a faster version of the RK4 method, but it may not be as accurate. It computes the contact forces only once at the beginning of the integration step. """ dt = model.time_step if len(model.kin_dyn_parameters.contact_parameters.body) > 0: # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact # with the terrain. W_f_L_terrain, contact_state_derivative = js.contact.link_contact_forces( model=model, data=data, link_forces=link_forces, joint_torques=joint_torques, ) W_f_L_total = link_forces + W_f_L_terrain # Update the contact state data. This is necessary only for the contact models # that require propagation and integration of contact state. 
contact_state = model.contact_model.update_contact_state(contact_state_derivative) def f(x) -> dict[str, jtp.Matrix]: with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): data_ti = data.replace(model=model, **x) W_v̇_WB, s̈ = js.model.forward_dynamics_aba( model=model, data=data_ti, joint_forces=joint_torques, link_forces=W_f_L_total, ) W_ṗ_B, W_Q̇_B, ṡ = js.ode.system_position_dynamics( data=data, baumgarte_quaternion_regularization=1.0, ) return dict( base_position=W_ṗ_B, base_quaternion=W_Q̇_B, joint_positions=ṡ, base_linear_velocity=W_v̇_WB[0:3], base_angular_velocity=W_v̇_WB[3:6], joint_velocities=s̈, # The contact state is not updated here, as it is assumed to be constant. contact_state=data_ti.contact_state, ) base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion, axis=-1) base_quaternion = data._base_quaternion / jnp.where( base_quaternion_norm == 0, 1.0, base_quaternion_norm ) x_t0 = dict( base_position=data._base_position, base_quaternion=base_quaternion, joint_positions=data._joint_positions, base_linear_velocity=data._base_linear_velocity, base_angular_velocity=data._base_angular_velocity, joint_velocities=data._joint_velocities, contact_state=contact_state, ) euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt euler_fin = lambda x, dxdt: x + dt * dxdt k1 = f(x_t0) k2 = f(jax.tree.map(euler_mid, x_t0, k1)) k3 = f(jax.tree.map(euler_mid, x_t0, k2)) k4 = f(jax.tree.map(euler_fin, x_t0, k3)) # Average the slopes and compute the RK4 state derivative. 
average = lambda k1, k2, k3, k4: (k1 + 2 * k2 + 2 * k3 + k4) / 6 dxdt = jax.tree.map(average, k1, k2, k3, k4) # Integrate the dynamics x_tf = jax.tree.map(euler_fin, x_t0, dxdt) data_tf = dataclasses.replace( data, _base_position=x_tf["base_position"], _base_quaternion=x_tf["base_quaternion"], _joint_positions=x_tf["joint_positions"], _base_linear_velocity=x_tf["base_linear_velocity"], _base_angular_velocity=x_tf["base_angular_velocity"], _joint_velocities=x_tf["joint_velocities"], contact_state=x_tf["contact_state"], ) return data_tf.replace(model=model) _INTEGRATORS_MAP: dict[ js.model.IntegratorType, Callable[..., js.data.JaxSimModelData] ] = { js.model.IntegratorType.SemiImplicitEuler: semi_implicit_euler_integration, js.model.IntegratorType.RungeKutta4: rk4_integration, js.model.IntegratorType.RungeKutta4Fast: rk4fast_integration, } ================================================ FILE: src/jaxsim/api/joint.py ================================================ import functools from collections.abc import Sequence import jax import jax.numpy as jnp import jaxsim.api as js import jaxsim.typing as jtp from jaxsim import exceptions # ======================= # Index-related functions # ======================= @functools.partial(jax.jit, static_argnames="joint_name") @js.common.named_scope def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int: """ Convert the name of a joint to its index. Args: model: The model to consider. joint_name: The name of the joint. Returns: The index of the joint. """ if joint_name not in model.joint_names(): raise ValueError(f"Joint '{joint_name}' not found in the model.") # Note: the index of the joint for RBDAs starts from 1, but the index for # accessing the right element starts from 0. Therefore, there is a -1. 
return ( jnp.array( model.kin_dyn_parameters.joint_model.joint_names.index(joint_name) - 1 ) .astype(int) .squeeze() ) def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str: """ Convert the index of a joint to its name. Args: model: The model to consider. joint_index: The index of the joint. Returns: The name of the joint. """ exceptions.raise_value_error_if( condition=joint_index < 0, msg="Invalid joint index '{idx}'", idx=joint_index, ) return model.kin_dyn_parameters.joint_model.joint_names[joint_index + 1] @functools.partial(jax.jit, static_argnames="joint_names") @js.common.named_scope def names_to_idxs( model: js.model.JaxSimModel, *, joint_names: Sequence[str] ) -> jax.Array: """ Convert a sequence of joint names to their corresponding indices. Args: model: The model to consider. joint_names: The names of the joints. Returns: The indices of the joints. """ return jnp.array( [name_to_idx(model=model, joint_name=name) for name in joint_names], ).astype(int) def idxs_to_names( model: js.model.JaxSimModel, *, joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike, ) -> tuple[str, ...]: """ Convert a sequence of joint indices to their corresponding names. Args: model: The model to consider. joint_indices: The indices of the joints. Returns: The names of the joints. """ return tuple(idx_to_name(model=model, joint_index=idx) for idx in joint_indices) # ============ # Joint limits # ============ @jax.jit def position_limit( model: js.model.JaxSimModel, *, joint_index: jtp.IntLike ) -> tuple[jtp.Float, jtp.Float]: """ Get the position limits of a joint. Args: model: The model to consider. joint_index: The index of the joint. Returns: The position limits of the joint. 
""" if model.number_of_joints() == 0: return jnp.empty(0).astype(float), jnp.empty(0).astype(float) exceptions.raise_value_error_if( condition=jnp.array( [joint_index < 0, joint_index >= model.number_of_joints()] ).any(), msg="Invalid joint index '{idx}'", idx=joint_index, ) s_min = jnp.atleast_1d( model.kin_dyn_parameters.joint_parameters.position_limits_min )[joint_index] s_max = jnp.atleast_1d( model.kin_dyn_parameters.joint_parameters.position_limits_max )[joint_index] return s_min.astype(float), s_max.astype(float) @functools.partial(jax.jit, static_argnames=["joint_names"]) @js.common.named_scope def position_limits( model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None ) -> tuple[jtp.Vector, jtp.Vector]: """ Get the position limits of a list of joint. Args: model: The model to consider. joint_names: The names of the joints. Returns: The position limits of the joints. """ joint_idxs = ( names_to_idxs(joint_names=joint_names, model=model) if joint_names is not None else jnp.arange(model.number_of_joints()) ) if len(joint_idxs) == 0: return jnp.empty(0).astype(float), jnp.empty(0).astype(float) s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_idxs] s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_idxs] return s_min.astype(float), s_max.astype(float) # ====================== # Random data generation # ====================== @functools.partial(jax.jit, static_argnames=["joint_names"]) @js.common.named_scope def random_joint_positions( model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None, key: jax.Array | None = None, ) -> jtp.Vector: """ Generate random joint positions. Args: model: The model to consider. joint_names: The names of the considered joints (all if None). key: The random key (initialized from seed 0 if None). Note: If the joint range or revolute joints is larger than 2π, their joint positions will be sampled from an interval of size 2π. 
Returns: The random joint positions. """ # Consider the key corresponding to a zero seed if it was not passed. key = key if key is not None else jax.random.PRNGKey(seed=0) # Get the joint limits parsed from the model description. s_min, s_max = position_limits(model=model, joint_names=joint_names) # Get the joint indices. # Note that it will trigger an exception if the given `joint_names` are not valid. joint_names = joint_names if joint_names is not None else model.joint_names() joint_indices = ( names_to_idxs(model=model, joint_names=joint_names) if joint_names is not None else jnp.arange(model.number_of_joints()) ) from jaxsim.parsers.descriptions.joint import JointType # Filter for revolute joints. is_revolute = jnp.where( jnp.array(model.kin_dyn_parameters.joint_model.joint_types[1:])[joint_indices] == JointType.Revolute, True, False, ) # Shorthand for π. π = jnp.pi # Filter for revolute with full range (or continuous). is_revolute_full_range = jnp.logical_and(is_revolute, s_max - s_min >= 2 * π) # Clip the lower limit to -π if the joint range is larger than [-π, π]. s_min = jnp.where( jnp.logical_and( is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π) ), -π, s_min, ) # Clip the upper limit to +π if the joint range is larger than [-π, π]. s_max = jnp.where( jnp.logical_and( is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π) ), π, s_max, ) # Shift the lower limit if the upper limit is smaller than +π. s_min = jnp.where( jnp.logical_and(is_revolute_full_range, s_max < π), s_max - 2 * π, s_min, ) # Shift the upper limit if the lower limit is larger than -π. s_max = jnp.where( jnp.logical_and(is_revolute_full_range, s_min > -π), s_min + 2 * π, s_max, ) # Sample the joint positions. 
s_random = jax.random.uniform(
    minval=s_min,
    maxval=s_max,
    key=key,
    shape=s_min.shape,
)

return s_random
================================================
FILE: src/jaxsim/api/kin_dyn_parameters.py
================================================
from __future__ import annotations

import dataclasses
from itertools import starmap
from typing import ClassVar

import jax.lax
import jax.numpy as jnp
import jax_dataclasses
import numpy as np
import numpy.typing as npt
from jax_dataclasses import Static

import jaxsim
import jaxsim.typing as jtp
from jaxsim.math import Inertia, JointModel, supported_joint_motion
from jaxsim.math.adjoint import Adjoint
from jaxsim.parsers.descriptions import JointDescription, JointType, ModelDescription
from jaxsim.utils import HashedNumpyArray, JaxsimDataclass


@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class KinDynParameters(JaxsimDataclass):
    r"""
    Class storing the kinematic and dynamic parameters of a model.

    Attributes:
        link_names: The names of the links.
        parent_array: The parent array :math:`\lambda(i)` of the model.
        support_body_array_bool:
            The boolean support parent array :math:`\kappa_{b}(i)` of the model.
        link_parameters: The parameters of the links.
        frame_parameters: The parameters of the frames.
        contact_parameters: The parameters of the collidable points.
        joint_model: The joint model of the model.
        joint_parameters: The parameters of the joints.
        hw_link_metadata: The hardware parameters of the model links.
        constraints: The kinematic constraints of the model. They can be used only with Relaxed-Rigid contact model.
    """

    # Static
    link_names: Static[tuple[str]]
    _parent_array: Static[HashedNumpyArray]
    _support_body_array_bool: Static[HashedNumpyArray]
    _motion_subspaces: Static[HashedNumpyArray]

    # Tree level structure for parallel algorithms.
    # level_nodes: (n_levels, max_width) array of link indices at each depth level,
    #   padded with 0 for levels with fewer nodes than max_width.
    # level_mask: (n_levels, max_width) boolean mask, True for real nodes.
    _level_nodes: Static[HashedNumpyArray]
    _level_mask: Static[HashedNumpyArray]

    # Links
    link_parameters: LinkParameters

    # Contacts
    contact_parameters: ContactParameters

    # Frames
    frame_parameters: FrameParameters

    # Joints
    joint_model: JointModel
    joint_parameters: JointParameters | None

    # Model hardware parameters
    hw_link_metadata: HwLinkMetadata | None = dataclasses.field(default=None)

    # Kinematic constraints
    constraints: ConstraintMap | None = dataclasses.field(default=None)

    @property
    def motion_subspaces(self) -> jtp.Matrix:
        r"""
        Return the motion subspaces :math:`\mathbf{S}(s)` of the joints.
        """

        return self._motion_subspaces.get()

    @property
    def parent_array(self) -> jtp.Vector:
        r"""
        Return the parent array :math:`\lambda(i)` of the model.
        """

        return self._parent_array.get()

    @property
    def support_body_array_bool(self) -> jtp.Matrix:
        r"""
        Return the boolean support parent array :math:`\kappa_{b}(i)` of the model.
        """

        return self._support_body_array_bool.get()

    @property
    def level_nodes(self) -> jtp.Matrix:
        r"""
        Return the tree level nodes array of shape ``(n_levels, max_width)``.

        Each row contains the link indices at the corresponding depth level,
        padded with 0 for levels with fewer nodes than ``max_width``.
        """

        return self._level_nodes.get()

    @property
    def level_mask(self) -> jtp.Matrix:
        r"""
        Return the tree level mask of shape ``(n_levels, max_width)``.

        Each entry is ``True`` for real nodes and ``False`` for padding.
        """

        return self._level_mask.get()

    @staticmethod
    def _compute_tree_levels(
        parent_array: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Compute the tree level decomposition from a parent array.

        Args:
            parent_array: Array of shape ``(n,)`` where ``parent_array[i]`` is the
                parent of link ``i``. ``parent_array[0] == -1`` for the root.

        Returns:
            A tuple ``(level_nodes, level_mask)`` where:
            - ``level_nodes`` has shape ``(n_levels, max_width)`` with link
              indices at each depth level (padded with 0).
            - ``level_mask`` has shape ``(n_levels, max_width)`` with ``True``
              for real nodes.
        """
        # NOTE(review): redundant with the module-level `import numpy as np`.
        import numpy as np

        n = len(parent_array)

        # Compute depth of each node.
        # NOTE(review): this single forward pass assumes parent_array[i] < i for
        # every i > 0 (parents listed before children) — confirm with the parser.
        depth = np.zeros(n, dtype=int)
        for i in range(1, n):
            depth[i] = depth[parent_array[i]] + 1

        max_depth = int(depth.max()) if n > 0 else 0
        n_levels = max_depth + 1

        # Group nodes by depth level.
        levels: list[list[int]] = [[] for _ in range(n_levels)]
        for i in range(n):
            levels[depth[i]].append(i)

        max_width = max(len(lev) for lev in levels) if levels else 1

        # Build padded arrays.
        level_nodes = np.zeros((n_levels, max_width), dtype=int)
        level_mask = np.zeros((n_levels, max_width), dtype=bool)
        for d, lev in enumerate(levels):
            for j, node_idx in enumerate(lev):
                level_nodes[d, j] = node_idx
                level_mask[d, j] = True

        return level_nodes, level_mask

    @staticmethod
    def build(
        model_description: ModelDescription, constraints: ConstraintMap | None
    ) -> KinDynParameters:
        """
        Construct the kinematic and dynamic parameters of the model.

        Args:
            model_description: The parsed model description to consider.
            constraints: An object of type ConstraintMap specifying the kinematic constraint of the model.

        Returns:
            The kinematic and dynamic parameters of the model.

        Note:
            This class is meant to ease the management of parametric models in
            an automatic differentiation context.
        """

        # Extract the links ordered by their index.
        # The link index corresponds to the body index ∈ [0, num_bodies - 1].
        ordered_links = sorted(
            list(model_description.links_dict.values()),
            key=lambda l: l.index,
        )

        # Extract the joints ordered by their index.
        # The joint index matches the index of its child link, therefore it starts
        # from 1. Keep this in mind since this 1-indexing might introduce bugs.
        ordered_joints = sorted(
            list(model_description.joints_dict.values()),
            key=lambda j: j.index,
        )

        # ================
        # Links properties
        # ================

        # Create a list of link parameters objects.
        link_parameters_list = [
            LinkParameters.build_from_spatial_inertia(index=link.index, M=link.inertia)
            for link in ordered_links
        ]

        # Create a vectorized object of link parameters.
        link_parameters = jax.tree.map(lambda *l: jnp.stack(l), *link_parameters_list)

        # =================
        # Joints properties
        # =================

        # Create a list of joint parameters objects.
        joint_parameters_list = [
            JointParameters.build_from_joint_description(joint_description=joint)
            for joint in ordered_joints
        ]

        # Create a vectorized object of joint parameters.
        # Models without joints get an object whose fields are empty arrays.
        joint_parameters = (
            jax.tree.map(lambda *l: jnp.stack(l), *joint_parameters_list)
            if ordered_joints
            else JointParameters(
                index=jnp.array([], dtype=int),
                friction_static=jnp.array([], dtype=float),
                friction_viscous=jnp.array([], dtype=float),
                position_limits_min=jnp.array([], dtype=float),
                position_limits_max=jnp.array([], dtype=float),
                position_limit_spring=jnp.array([], dtype=float),
                position_limit_damper=jnp.array([], dtype=float),
            )
        )

        # Create an object that defines the joint model (parent-to-child transforms).
        joint_model = JointModel.build(description=model_description)

        # ===================
        # Contacts properties
        # ===================

        # Create the object storing the parameters of collidable points.
        # Note that, contrarily to LinkParameters and JointsParameters, this object
        # is not created with vmap. This is because the "body" attribute of the object
        # must be Static for JIT-related reasons, and tree_map would not consider it
        # as a leaf.
        contact_parameters = ContactParameters.build_from(
            model_description=model_description
        )

        # =================
        # Frames properties
        # =================

        # Create the object storing the parameters of frames.
        # Note that, contrarily to LinkParameters and JointsParameters, this object
        # is not created with vmap. This is because the "name" attribute of the object
        # must be Static for JIT-related reasons, and tree_map would not consider it
        # as a leaf.
        frame_parameters = FrameParameters.build_from(
            model_description=model_description
        )

        # ===============
        # Tree properties
        # ===============

        # Build the parent array λ(i) of the model.
        # Note: the parent of the base link is not set since it's not defined.
        parent_array_dict = {
            link.index: model_description.links_dict[link.parent_name].index
            for link in ordered_links
            if link.parent_name is not None
        }
        parent_array = jnp.array([-1, *list(parent_array_dict.values())], dtype=int)

        # Instead of building the support parent array κ(i) for each link of the model,
        # that has a variable length depending on the number of links connecting the
        # root to the i-th link, we build the corresponding boolean version.
        # Given a link index i, the boolean support parent array κb(i) is an array
        # with the same number of elements of λ(i) having the i-th element set to True
        # if the i-th link is in the support parent array κ(i), False otherwise.
        # We store the boolean κb(i) as static attribute of the PyTree so that
        # algorithms that need to access it can be jit-compiled.
        def κb(link_index: jtp.IntLike) -> jtp.Vector:
            κb = jnp.zeros(len(ordered_links), dtype=bool)

            carry0 = κb, link_index

            # Scan link indices from the last to the first. Whenever the scanned index
            # equals the active link, mark it and move the active link to its parent.
            # NOTE(review): this walk relies on parents having smaller indices than
            # their children — confirm this ordering is guaranteed by the parser.
            def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:

                κb, active_link_index = carry

                κb, active_link_index = jax.lax.cond(
                    pred=(i == active_link_index),
                    false_fun=lambda: (κb, active_link_index),
                    true_fun=lambda: (
                        κb.at[active_link_index].set(True),
                        parent_array[active_link_index],
                    ),
                )

                return (κb, active_link_index), None

            (κb, _), _ = jax.lax.scan(
                f=scan_body,
                init=carry0,
                xs=jnp.flip(jnp.arange(start=0, stop=len(ordered_links))),
            )

            return κb

        support_body_array_bool = jax.vmap(κb)(
            jnp.arange(start=0, stop=len(ordered_links))
        )

        # Build the 6x1 motion subspace of a joint; `axis.axis` is the joint axis
        # (angular part for revolute joints, linear part for prismatic joints).
        def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:

            S = {
                JointType.Fixed: np.zeros(shape=(6, 1)),
                JointType.Revolute: np.vstack(np.hstack([np.zeros(3), axis.axis])),
                JointType.Prismatic: np.vstack(np.hstack([axis.axis, np.zeros(3)])),
            }

            return S[joint_type]

        # joint_types[0] belongs to the base joint, hence the [1:] slice.
        S_J = jnp.array(
            list(
                starmap(
                    motion_subspace,
                    zip(
                        joint_model.joint_types[1:], joint_model.joint_axis, strict=True
                    ),
                )
            )
            if len(joint_model.joint_axis) != 0
            else jnp.empty((0, 6, 1))
        )

        # Prepend a zero motion subspace at index 0 for the base.
        motion_subspaces = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])

        # ====================
        # Tree level structure
        # ====================

        # NumPy copy of the parent array for the host-side level decomposition.
        parent_array_np = np.array([-1, *list(parent_array_dict.values())], dtype=int)
        level_nodes, level_mask = KinDynParameters._compute_tree_levels(parent_array_np)

        # ===========
        # Constraints
        # ===========

        constraints = ConstraintMap() if constraints is None else constraints

        # =================================
        # Build and return KinDynParameters
        # =================================

        return KinDynParameters(
            link_names=tuple(l.name for l in ordered_links),
            _parent_array=HashedNumpyArray(array=parent_array),
            _support_body_array_bool=HashedNumpyArray(array=support_body_array_bool),
            _motion_subspaces=HashedNumpyArray(array=motion_subspaces),
            _level_nodes=HashedNumpyArray(array=level_nodes),
            _level_mask=HashedNumpyArray(array=level_mask),
            link_parameters=link_parameters,
            joint_model=joint_model,
            joint_parameters=joint_parameters,
            contact_parameters=contact_parameters,
            frame_parameters=frame_parameters,
            constraints=constraints,
        )

    # Equality is delegated to __hash__: two objects compare equal when their
    # hashed (topology-related) fields match.
    def __eq__(self, other: KinDynParameters) -> bool:

        if not isinstance(other, KinDynParameters):
            return False

        return hash(self) == hash(other)

    # Only structural data is hashed (counts, frame names/bodies, parent and
    # support arrays); numeric link/joint parameters do not enter the hash.
    def __hash__(self) -> int:

        return hash(
            (
                hash(self.number_of_links()),
                hash(self.number_of_joints()),
                hash(self.frame_parameters.name),
                hash(self.frame_parameters.body),
                hash(self._parent_array),
                hash(self._support_body_array_bool),
            )
        )

    # =============================
    # Helpers to extract parameters
    # =============================

    def number_of_links(self) -> int:
        """
        Return the number of links of the model.

        Returns:
            The number of links of the model.
        """

        return len(self.link_names)

    def number_of_joints(self) -> int:
        """
        Return the number of joints of the model.

        Returns:
            The number of joints of the model.
        """

        # joint_names includes an entry for the base joint at index 0.
        return len(self.joint_model.joint_names) - 1

    def number_of_frames(self) -> int:
        """
        Return the number of frames of the model.

        Returns:
            The number of frames of the model.
        """

        return len(self.frame_parameters.name)

    def support_body_array(self, link_index: jtp.IntLike) -> jtp.Vector:
        r"""
        Return the support parent array :math:`\kappa(i)` of a link.

        Args:
            link_index: The index of the link.

        Returns:
            The support parent array :math:`\kappa(i)` of the link.

        Note:
            This method returns a variable-length vector. In jit-compiled functions,
            it's better to use the (static) boolean version `support_body_array_bool`.
        """

        return jnp.array(
            jnp.where(self.support_body_array_bool[link_index])[0], dtype=int
        )

    # ========================
    # Quantities used by RBDAs
    # ========================

    @jax.jit
    def links_spatial_inertia(self) -> jtp.Array:
        """
        Return the spatial inertia of all links of the model.

        Returns:
            The spatial inertia of all links of the model.
        """

        return jax.vmap(LinkParameters.spatial_inertia)(self.link_parameters)

    @jax.jit
    def tree_transforms(self) -> jtp.Array:
        r"""
        Return the tree transforms of the model.

        Returns:
            The transforms
            :math:`{}^{\text{pre}(i)} H_{\lambda(i)}`
            of all joints of the model.
        """

        pre_Xi_λ = jax.vmap(
            lambda i: self.joint_model.parent_H_predecessor(joint_index=i)
            .inverse()
            .adjoint()
        )(jnp.arange(1, self.number_of_joints() + 1))

        # Index 0 (base) is filled with a zero 6x6 matrix.
        return jnp.vstack(
            [
                jnp.zeros(shape=(1, 6, 6), dtype=float),
                pre_Xi_λ,
            ]
        )

    @jax.jit
    def joint_transforms(
        self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike
    ) -> jtp.Array:
        r"""
        Return the transforms of the joints.

        Args:
            joint_positions: The joint positions.
            base_transform: The homogeneous matrix defining the base pose.

        Returns:
            The stacked transforms
            :math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)`
            of each joint.
        """

        # Rename the base transform.
        W_H_B = base_transform

        # Extract the parent-to-predecessor fixed transforms of the joints.
        λ_H_pre = jnp.vstack(
            [
                jnp.eye(4)[jnp.newaxis],
                self.joint_model.λ_H_pre[1 : 1 + self.number_of_joints()],
            ]
        )

        if self.number_of_joints() == 0:
            pre_H_suc_J = jnp.empty((0, 4, 4))
        else:
            pre_H_suc_J = jax.vmap(supported_joint_motion)(
                joint_types=jnp.array(self.joint_model.joint_types[1:]).astype(int),
                joint_positions=jnp.array(joint_positions),
                joint_axes=jnp.array([j.axis for j in self.joint_model.joint_axis]),
            )

        # Extract the transforms and motion subspaces of the joints.
        # We stack the base transform W_H_B at index 0, and a dummy motion subspace
        # for either the fixed or free-floating joint connecting the world to the base.
        pre_H_suc = jnp.vstack([W_H_B[jnp.newaxis, ...], pre_H_suc_J])

        # Extract the successor-to-child fixed transforms.
        # Note that here we include also the index 0 since suc_H_child[0] stores the
        # optional pose of the base link w.r.t. the root frame of the model.
        # This is supported by SDF when the base link element is defined.
        suc_H_i = self.joint_model.suc_H_i[jnp.arange(0, 1 + self.number_of_joints())]

        # Compute the overall transforms from the parent to the child of each joint by
        # composing all the components of our joint model.
        i_X_λ = jax.vmap(
            lambda λ_Hi_pre, pre_Hi_suc, suc_Hi_i: Adjoint.from_transform(
                transform=λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i, inverse=True
            )
        )(λ_H_pre, pre_H_suc, suc_H_i)

        return i_X_λ

    # ============================
    # Helpers to update parameters
    # ============================

    def set_link_mass(
        self, link_index: jtp.IntLike, mass: jtp.FloatLike
    ) -> KinDynParameters:
        """
        Set the mass of a link.

        Args:
            link_index: The index of the link.
            mass: The mass of the link.

        Returns:
            The updated kinematic and dynamic parameters of the model.
        """

        # Functional update: a new object is returned, self is left untouched.
        link_parameters = self.link_parameters.replace(
            mass=self.link_parameters.mass.at[link_index].set(mass)
        )

        return self.replace(link_parameters=link_parameters)

    def set_link_inertia(
        self, link_index: jtp.IntLike, inertia: jtp.MatrixLike
    ) -> KinDynParameters:
        r"""
        Set the inertia tensor of a link.

        Args:
            link_index: The index of the link.
            inertia: The :math:`3 \times 3` inertia tensor of the link.

        Returns:
            The updated kinematic and dynamic parameters of the model.
        """

        inertia_elements = LinkParameters.flatten_inertia_tensor(I=inertia)

        link_parameters = self.link_parameters.replace(
            inertia_elements=self.link_parameters.inertia_elements.at[link_index].set(
                inertia_elements
            )
        )

        return self.replace(link_parameters=link_parameters)


@jax_dataclasses.pytree_dataclass
class JointParameters(JaxsimDataclass):
    """
    Class storing the parameters of a joint.

    Attributes:
        index: The index of the joint.
        friction_static: The static friction of the joint.
        friction_viscous: The viscous friction of the joint.
        position_limits_min: The lower position limit of the joint.
        position_limits_max: The upper position limit of the joint.
        position_limit_spring: The spring constant of the position limit.
        position_limit_damper: The damper constant of the position limit.

    Note:
        This class is used inside KinDynParameters to store the vectorized set
        of joint parameters.
    """

    index: jtp.Int
    friction_static: jtp.Float
    friction_viscous: jtp.Float
    position_limits_min: jtp.Float
    position_limits_max: jtp.Float
    position_limit_spring: jtp.Float
    position_limit_damper: jtp.Float

    @staticmethod
    def build_from_joint_description(
        joint_description: JointDescription,
    ) -> JointParameters:
        """
        Build a JointParameters object from a joint description.

        Args:
            joint_description: The joint description to consider.

        Returns:
            The JointParameters object.
        """

        s_min = joint_description.position_limit[0]
        s_max = joint_description.position_limit[1]

        # Sort the limits so that min <= max even if the description swaps them.
        position_limits_min = jnp.minimum(s_min, s_max)
        position_limits_max = jnp.maximum(s_min, s_max)

        friction_static = jnp.array(joint_description.friction_static).squeeze()
        friction_viscous = jnp.array(joint_description.friction_viscous).squeeze()

        position_limit_spring = jnp.array(
            joint_description.position_limit_spring
        ).squeeze()

        position_limit_damper = jnp.array(
            joint_description.position_limit_damper
        ).squeeze()

        return JointParameters(
            index=jnp.array(joint_description.index).squeeze().astype(int),
            friction_static=friction_static.astype(float),
            friction_viscous=friction_viscous.astype(float),
            position_limits_min=position_limits_min.astype(float),
            position_limits_max=position_limits_max.astype(float),
            position_limit_spring=position_limit_spring.astype(float),
            position_limit_damper=position_limit_damper.astype(float),
        )


@jax_dataclasses.pytree_dataclass
class LinkParameters(JaxsimDataclass):
    r"""
    Class storing the parameters of a link.

    Attributes:
        index: The index of the link.
        mass: The mass of the link.
        inertia_elements:
            The unique elements of the :math:`3 \times 3` inertia tensor of the link.
center_of_mass: The translation :math:`{}^L \mathbf{p}_{\text{CoM}}` between the origin of the link frame and the link's center of mass, expressed in the coordinates of the link frame. Note: This class is used inside KinDynParameters to store the vectorized set of link parameters. """ index: jtp.Int mass: jtp.Float center_of_mass: jtp.Vector inertia_elements: jtp.Vector @staticmethod def build_from_spatial_inertia(index: jtp.IntLike, M: jtp.Matrix) -> LinkParameters: r""" Build a LinkParameters object from a :math:`6 \times 6` spatial inertia matrix. Args: index: The index of the link. M: The :math:`6 \times 6` spatial inertia matrix of the link. Returns: The LinkParameters object. """ # Extract the link parameters from the 6D spatial inertia. m, L_p_CoM, I_CoM = Inertia.to_params(M=M) # Extract only the necessary elements of the inertia tensor. inertia_elements = I_CoM[jnp.triu_indices(3)] return LinkParameters( index=jnp.array(index).squeeze().astype(int), mass=jnp.array(m).squeeze().astype(float), center_of_mass=jnp.atleast_1d(jnp.array(L_p_CoM).squeeze()).astype(float), inertia_elements=jnp.atleast_1d(inertia_elements.squeeze()).astype(float), ) @staticmethod def build_from_inertial_parameters( index: jtp.IntLike, m: jtp.FloatLike, I: jtp.MatrixLike, c: jtp.VectorLike ) -> LinkParameters: r""" Build a LinkParameters object from the inertial parameters of a link. Args: index: The index of the link. m: The mass of the link. I: The :math:`3 \times 3` inertia tensor of the link. c: The translation between the link frame and the link's center of mass. Returns: The LinkParameters object. """ # Extract only the necessary elements of the inertia tensor. 
inertia_elements = I[jnp.triu_indices(3)] return LinkParameters( index=jnp.array(index).squeeze().astype(int), mass=jnp.array(m).squeeze().astype(float), center_of_mass=jnp.atleast_1d(c.squeeze()).astype(float), inertia_elements=jnp.atleast_1d(inertia_elements.squeeze()).astype(float), ) @staticmethod def build_from_flat_parameters( index: jtp.IntLike, parameters: jtp.VectorLike ) -> LinkParameters: """ Build a LinkParameters object from a flat vector of parameters. Args: index: The index of the link. parameters: The flat vector of parameters. Returns: The LinkParameters object. """ index = jnp.array(index).squeeze().astype(int) m = jnp.array(parameters[0]).squeeze().astype(float) c = jnp.atleast_1d(parameters[1:4].squeeze()).astype(float) inertia_elements = jnp.atleast_1d(parameters[4:].squeeze()).astype(float) return LinkParameters( index=index, mass=m, inertia_elements=inertia_elements, center_of_mass=c ) @staticmethod def flat_parameters(params: LinkParameters) -> jtp.Vector: """ Return the parameters of a link as a flat vector. Args: params: The link parameters. Returns: The parameters of the link as a flat vector. """ return ( jnp.hstack( [ params.mass, params.center_of_mass.squeeze(), params.inertia_elements, ] ) .squeeze() .astype(float) ) @staticmethod def inertia_tensor(params: LinkParameters) -> jtp.Matrix: r""" Return the :math:`3 \times 3` inertia tensor of a link. Args: params: The link parameters. Returns: The :math:`3 \times 3` inertia tensor of the link. """ return LinkParameters.unflatten_inertia_tensor( inertia_elements=params.inertia_elements ) @staticmethod def spatial_inertia(params: LinkParameters) -> jtp.Matrix: r""" Return the :math:`6 \times 6` spatial inertia matrix of a link. Args: params: The link parameters. Returns: The :math:`6 \times 6` spatial inertia matrix of the link. 
""" return Inertia.to_sixd( mass=params.mass, I=LinkParameters.inertia_tensor(params), com=params.center_of_mass, ) @staticmethod def flatten_inertia_tensor(I: jtp.Matrix) -> jtp.Vector: r""" Flatten a :math:`3 \times 3` inertia tensor into a vector of unique elements. Args: I: The :math:`3 \times 3` inertia tensor. Returns: The vector of unique elements of the inertia tensor. """ return jnp.atleast_1d(I[jnp.triu_indices(3)].squeeze()) @staticmethod def unflatten_inertia_tensor(inertia_elements: jtp.Vector) -> jtp.Matrix: r""" Unflatten a vector of unique elements into a :math:`3 \times 3` inertia tensor. Args: inertia_elements: The vector of unique elements of the inertia tensor. Returns: The :math:`3 \times 3` inertia tensor. """ I = jnp.zeros([3, 3]).at[jnp.triu_indices(3)].set(inertia_elements.squeeze()) return jnp.atleast_2d(jnp.where(I, I, I.T)).astype(float) @jax_dataclasses.pytree_dataclass class ContactParameters(JaxsimDataclass): """ Class storing the contact parameters of a model. Attributes: body: A tuple of integers representing, for each collidable point, the index of the body (link) to which it is rigidly attached to. point: The translations between the link frame and the collidable point, expressed in the coordinates of the parent link frame. enabled: A tuple of booleans representing, for each collidable point, whether it is enabled or not in contact models. Note: Contrarily to LinkParameters and JointParameters, this class is not meant to be created with vmap. This is because the `body` attribute must be `Static`. """ body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple) point: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([])) enabled: Static[tuple[bool, ...]] = dataclasses.field(default_factory=tuple) @property def indices_of_enabled_collidable_points(self) -> npt.NDArray: """ Return the indices of the enabled collidable points. 
""" return np.where(np.array(self.enabled))[0] @staticmethod def build_from(model_description: ModelDescription) -> ContactParameters: """ Build a ContactParameters object from a model description. Args: model_description: The model description to consider. Returns: The ContactParameters object. """ if len(model_description.collision_shapes) == 0: return ContactParameters() # Get all the links so that we can take their updated index. links_dict = {link.name: link for link in model_description} # Get all the enabled collidable points of the model. collidable_points = model_description.all_enabled_collidable_points() # Extract the positions L_p_C of the collidable points w.r.t. the link frames # they are rigidly attached to. points = jnp.vstack([cp.position for cp in collidable_points]) # Extract the indices of the links to which the collidable points are rigidly # attached to. link_index_of_points = tuple( links_dict[cp.parent_link.name].index for cp in collidable_points ) # Build the ContactParameters object. cp = ContactParameters( point=points, body=link_index_of_points, enabled=tuple(True for _ in link_index_of_points), ) assert cp.point.shape[1] == 3, cp.point.shape[1] assert cp.point.shape[0] == len(cp.body), cp.point.shape[0] return cp @jax_dataclasses.pytree_dataclass class FrameParameters(JaxsimDataclass): """ Class storing the frame parameters of a model. Attributes: name: A tuple of strings defining the frame names. body: A vector of integers representing, for each frame, the index of the body (link) to which it is rigidly attached to. transform: The transforms of the frames w.r.t. their parent link. Note: Contrarily to LinkParameters and JointParameters, this class is not meant to be created with vmap. This is because the `name` attribute must be `Static`. 
""" name: Static[tuple[str, ...]] = dataclasses.field(default_factory=tuple) body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple) transform: jtp.Array = dataclasses.field(default_factory=lambda: jnp.array([])) @staticmethod def build_from(model_description: ModelDescription) -> FrameParameters: """ Build a FrameParameters object from a model description. Args: model_description: The model description to consider. Returns: The FrameParameters object. """ if len(model_description.frames) == 0: return FrameParameters() # Extract the frame names. names = tuple(frame.name for frame in model_description.frames) # For each frame, extract the index of the link to which it is attached to. parent_link_index_of_frames = tuple( model_description.links_dict[frame.parent_name].index for frame in model_description.frames ) # For each frame, extract the transform w.r.t. its parent link. transforms = jnp.atleast_3d( jnp.stack([frame.pose for frame in model_description.frames]) ) # Build the FrameParameters object. fp = FrameParameters( name=names, transform=transforms.astype(float), body=parent_link_index_of_frames, ) assert fp.transform.shape[1:] == (4, 4), fp.transform.shape[1:] assert fp.transform.shape[0] == len(fp.body), fp.transform.shape[0] return fp @dataclasses.dataclass(frozen=True) class LinkParametrizableShape: """ Enum-like class listing the supported shapes for HW parametrization. """ Unsupported: ClassVar[int] = -1 Box: ClassVar[int] = 0 Cylinder: ClassVar[int] = 1 Sphere: ClassVar[int] = 2 Mesh: ClassVar[int] = 3 @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) class HwLinkMetadata(JaxsimDataclass): """ Class storing the hardware parameters of a link. Attributes: link_shape: The shape of the link. 0 = box, 1 = cylinder, 2 = sphere, 3 = mesh, -1 = unsupported. geometry: Shape parameters used by HW parametrization. 
box: [lx,ly,lz], cylinder: [r,l,0], sphere: [r,0,0], mesh: cumulative anisotropic scale factors [sx,sy,sz] (initialized to [1,1,1]). density: The density of the link. L_H_G: The homogeneous transformation matrix from the link frame to the CoM frame G. L_H_vis: The homogeneous transformation matrix from the link frame to the visual frame. L_H_pre_mask: The mask indicating the link's child joint indices. L_H_pre: The homogeneous transforms for child joints. mesh_moments: Precomputed volumetric moments for mesh shapes (n_links x 13). Each row stores [V_ref, com_x, com_y, com_z, Σ_00..Σ_22] where V_ref is the reference volume, com is the volumetric center of mass, and Σ is the volumetric covariance matrix at the origin. Zero for non-mesh links. mesh_vertices: The original centered mesh vertices (Nx3) for mesh shapes, None otherwise. mesh_faces: The mesh triangle faces (Mx3 integer indices) for mesh shapes, None otherwise. mesh_offset: The original mesh centroid offset (3D vector) for mesh shapes, None otherwise. mesh_uri: The path to the mesh file for reference, None otherwise. """ link_shape: jtp.Vector geometry: jtp.Vector density: jtp.Float L_H_G: jtp.Matrix L_H_vis: jtp.Matrix L_H_pre_mask: jtp.Vector L_H_pre: jtp.Matrix mesh_moments: jtp.Matrix mesh_vertices: Static[tuple[HashedNumpyArray | None, ...] | None] mesh_faces: Static[tuple[HashedNumpyArray | None, ...] | None] mesh_offset: Static[tuple[HashedNumpyArray | None, ...] | None] mesh_uri: Static[tuple[str | None, ...] 
| None]

    @classmethod
    def empty(cls) -> HwLinkMetadata:
        """Return hardware metadata representing the absence of links."""

        # Zero-length arrays (and a 0x13 moments matrix) stand for "no links".
        return cls(
            link_shape=jnp.array([], dtype=int),
            geometry=jnp.array([], dtype=float),
            density=jnp.array([], dtype=float),
            L_H_G=jnp.array([], dtype=float),
            L_H_vis=jnp.array([], dtype=float),
            L_H_pre_mask=jnp.array([], dtype=bool),
            L_H_pre=jnp.array([], dtype=float),
            mesh_moments=jnp.zeros((0, 13), dtype=float),
            mesh_vertices=None,
            mesh_faces=None,
            mesh_offset=None,
            mesh_uri=None,
        )

    @staticmethod
    def compute_mesh_inertia(
        vertices: jtp.Matrix, faces: jtp.Matrix, density: jtp.Float
    ) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]:
        """
        Compute mass, center of mass, and inertia tensor from mesh geometry.

        Uses the divergence theorem to compute volumetric properties by integrating
        over tetrahedra formed between the mesh surface and the origin.

        Args:
            vertices: Mesh vertices (Nx3) in the link frame, should be centered.
            faces: Triangle face indices (Mx3), integer indices into vertices array.
            density: Material density.

        Returns:
            A tuple containing the computed mass, the CoM position and the 3x3
            inertia tensor at the CoM. For a (near-)zero enclosed volume, all
            three are zero.
        """

        # Extract triangles from vertices using face indices
        triangles = vertices[faces.astype(int)]
        A, B, C = triangles[:, 0], triangles[:, 1], triangles[:, 2]

        # Compute signed volume of tetrahedra relative to origin
        # vol = 1/6 * (A . (B x C))
        tetrahedron_volumes = jnp.sum(A * jnp.cross(B, C), axis=1) / 6.0

        total_signed_volume = jnp.sum(tetrahedron_volumes)

        # Normalize the global winding sign so positive density yields non-negative mass.
        orientation_sign = jnp.where(total_signed_volume < 0, -1.0, 1.0)
        tetrahedron_volumes = tetrahedron_volumes * orientation_sign
        total_volume = jnp.sum(tetrahedron_volumes)

        # Guard against degenerate meshes: divide by a safe value and zero the
        # results when the enclosed volume is below eps.
        eps = jnp.asarray(1e-12, dtype=total_volume.dtype)
        is_valid_volume = jnp.abs(total_volume) > eps
        safe_total_volume = jnp.where(is_valid_volume, total_volume, 1.0)

        mass = jnp.where(is_valid_volume, total_volume * density, 0.0)

        # Compute center of mass
        tet_coms = (A + B + C) / 4.0
        com_position = jnp.where(
            is_valid_volume,
            jnp.sum(tet_coms * tetrahedron_volumes[:, None], axis=0)
            / safe_total_volume,
            jnp.zeros(3, dtype=vertices.dtype),
        )

        # Compute inertia tensor with covariance approach
        def compute_tetrahedron_covariance(a, b, c, vol):
            # Second moment of one tetrahedron (origin, a, b, c) about the origin.
            s = a + b + c
            return (vol / 20.0) * (
                jnp.outer(a, a) + jnp.outer(b, b) + jnp.outer(c, c) + jnp.outer(s, s)
            )

        covariance_matrices = jax.vmap(compute_tetrahedron_covariance)(
            A, B, C, tetrahedron_volumes
        )

        Σ_origin = jnp.sum(covariance_matrices, axis=0)

        # Shift to CoM using parallel axis theorem
        Σ_com = Σ_origin * density - mass * jnp.outer(com_position, com_position)

        # Convert covariance to inertia tensor
        I_com = jnp.trace(Σ_com) * jnp.eye(3, dtype=vertices.dtype) - Σ_com

        I_com = jnp.where(
            is_valid_volume, I_com, jnp.zeros((3, 3), dtype=vertices.dtype)
        )

        return mass, com_position, I_com

    @staticmethod
    def precompute_mesh_moments(vertices: np.ndarray, faces: np.ndarray) -> np.ndarray:
        """
        Precompute volumetric moments from reference mesh geometry.

        Computes the reference volume, center of mass, and volumetric covariance
        matrix at the origin using numpy. These 13 scalars are sufficient to
        analytically reconstruct mass and inertia under any anisotropic scaling,
        avoiding the need to embed full mesh arrays in JIT-compiled programs.

        Args:
            vertices: Mesh vertices (Nx3), should be centered.
            faces: Triangle face indices (Mx3).

        Returns:
            A 13-element array: [V_ref, com_x, com_y, com_z, Σ_00..Σ_22].
            All zeros when the enclosed volume is below 1e-12.
        """

        triangles = vertices[faces.astype(int)]
        A, B, C = triangles[:, 0], triangles[:, 1], triangles[:, 2]

        # Signed tetrahedron volumes, flipped so that the total is non-negative.
        volumes = np.sum(A * np.cross(B, C), axis=1) / 6.0
        total_signed = np.sum(volumes)
        sign = np.sign(total_signed) if abs(total_signed) > 1e-12 else 1.0
        volumes = volumes * sign

        V_ref = np.sum(volumes)
        if abs(V_ref) < 1e-12:
            return np.zeros(13, dtype=np.float64)

        # Center of mass
        com = np.sum(volumes[:, None] * (A + B + C) / 4.0, axis=0) / V_ref

        # Volumetric covariance at origin (same formula as compute_mesh_inertia)
        S = A + B + C
        cov = (volumes[:, None, None] / 20.0) * (
            A[:, :, None] * A[:, None, :]
            + B[:, :, None] * B[:, None, :]
            + C[:, :, None] * C[:, None, :]
            + S[:, :, None] * S[:, None, :]
        )
        Sigma = np.sum(cov, axis=0)

        return np.concatenate([[V_ref], com, Sigma.flatten()])

    @staticmethod
    def compute_mesh_inertia_from_moments(
        moments: jtp.Vector, dims: jtp.Vector, density: jtp.Float
    ) -> tuple[jtp.Float, jtp.Matrix]:
        """
        Compute mass and inertia tensor from precomputed volumetric moments.

        Uses analytical scaling laws to derive physical properties under
        anisotropic scaling without requiring the full mesh geometry.

        Under scaling S = diag(sx, sy, sz):
            - V' = det(S) * V_ref
            - com' = S @ com_ref
            - Σ_origin' = det(S) * S @ Σ_ref @ S

        Args:
            moments: Precomputed moments array of length 13.
            dims: Current anisotropic scale factors [sx, sy, sz].
            density: Current material density.

        Returns:
            A tuple of (mass, inertia_at_com). Both are zero when V_ref <= 1e-12.
        """

        # Unpack the layout produced by precompute_mesh_moments.
        V_ref = moments[0]
        com_ref = moments[1:4]
        Sigma_ref = moments[4:13].reshape(3, 3)

        det_s = dims[0] * dims[1] * dims[2]
        S = jnp.diag(dims)

        mass = density * V_ref * det_s
        com = dims * com_ref

        Sigma_scaled = det_s * (S @ Sigma_ref @ S)

        # Shift the covariance to the CoM, then convert it to an inertia tensor.
        Sigma_com = density * Sigma_scaled - mass * jnp.outer(com, com)
        I_com = jnp.trace(Sigma_com) * jnp.eye(3) - Sigma_com

        # An all-zero moments row (non-mesh or degenerate mesh) yields zero outputs.
        is_valid = V_ref > 1e-12
        mass = jnp.where(is_valid, mass, 0.0)
        I_com = jnp.where(is_valid, I_com, jnp.zeros((3, 3)))

        return mass, I_com

    @staticmethod
    def compute_mass_and_inertia(
        hw_link_metadata: HwLinkMetadata,
    ) -> tuple[jtp.Float, jtp.Matrix]:
        """
        Compute the mass and inertia of a hardware link based on its metadata.

        This function calculates the mass and inertia tensor of a hardware link
        using its shape, dimensions, and density. The computation is performed
        by using shape-specific methods.

        Args:
            hw_link_metadata: Metadata describing the hardware link,
                including its shape, dimensions, and density.

        Returns:
            tuple: A tuple containing:
                - mass: The computed mass of the hardware link.
                - inertia: The computed inertia tensor of the hardware link.
""" def box(dims, density, _moments) -> tuple[jtp.Float, jtp.Matrix]: lx, ly, lz = dims mass = density * lx * ly * lz inertia = jnp.array( [ [mass * (ly**2 + lz**2) / 12, 0, 0], [0, mass * (lx**2 + lz**2) / 12, 0], [0, 0, mass * (lx**2 + ly**2) / 12], ] ) return mass, inertia def cylinder(dims, density, _moments) -> tuple[jtp.Float, jtp.Matrix]: r, l, _ = dims mass = density * (jnp.pi * r**2 * l) inertia = jnp.array( [ [mass * (3 * r**2 + l**2) / 12, 0, 0], [0, mass * (3 * r**2 + l**2) / 12, 0], [0, 0, mass * (r**2) / 2], ] ) return mass, inertia def sphere(dims, density, _moments) -> tuple[jtp.Float, jtp.Matrix]: r = dims[0] mass = density * (4 / 3 * jnp.pi * r**3) inertia = jnp.eye(3) * (2 / 5 * mass * r**2) return mass, inertia def mesh(dims, density, moments) -> tuple[jtp.Float, jtp.Matrix]: return HwLinkMetadata.compute_mesh_inertia_from_moments( moments, dims, density ) def compute_mass_inertia(shape_idx, dims, density, moments): def unsupported_case(_): return ( jnp.asarray(0.0, dtype=density.dtype), jnp.zeros((3, 3), dtype=density.dtype), ) def supported_case(idx): return jax.lax.switch( idx, (box, cylinder, sphere, mesh), dims, density, moments ) return jax.lax.cond( shape_idx < 0, unsupported_case, supported_case, shape_idx ) masses, inertias = jax.vmap(compute_mass_inertia)( hw_link_metadata.link_shape, hw_link_metadata.geometry, hw_link_metadata.density, hw_link_metadata.mesh_moments, ) return masses, inertias @staticmethod def _convert_scaling_to_3d_vector( link_shapes: jtp.Int, scaling_factors: jtp.Vector ) -> jtp.Vector: """ Convert scaling factors for specific shape dimensions into a 3D scaling vector. Args: link_shapes: The link_shapes of the link (e.g., box, sphere, cylinder, mesh). scaling_factors: The scaling factors for the shape dimensions. Returns: A 3D scaling vector to apply to position vectors. 
Note: The scaling factors are applied as follows to generate the 3D scale vector: - Box: [lx, ly, lz] - Cylinder: [r, r, l] - Sphere: [r, r, r] - Mesh: [sx, sy, sz] """ # Index mapping for each shape type (link_shapes x 3 dims) # Box: [lx, ly, lz] -> [0, 1, 2] # Cylinder: [r, r, l] -> [0, 0, 1] # Sphere: [r, r, r] -> [0, 0, 0] # Mesh: [sx, sy, sz] -> [0, 1, 2] shape_indices = jnp.array( [ [0, 1, 2], # Box [0, 0, 1], # Cylinder [0, 0, 0], # Sphere [0, 1, 2], # Mesh ] ) # For each link, get the index vector for its shape per_link_indices = shape_indices[link_shapes] # Gather dims per link according to per_link_indices return scaling_factors.dims[per_link_indices.squeeze()] @staticmethod def compute_contact_points( original_contact_params: jtp.Vector, link_shapes: jtp.Vector, original_com_positions: jtp.Vector, updated_com_positions: jtp.Vector, scaling_factors: ScalingFactors, ) -> jtp.Matrix: """ Compute the new contact points based on the original contact parameters and the scaling factors. Args: original_contact_params: The original contact parameters. link_shapes: The shape types of the links (e.g., box, sphere, cylinder). original_com_positions: The original center of mass positions of the links. updated_com_positions: The updated center of mass positions of the links. scaling_factors: The scaling factors for the link dimensions. Returns: The new contact points positions in the parent link frame. """ parent_link_indices = np.array(original_contact_params.body) # Translate the original contact point positions in the origin, so # that we can apply the scaling factors. L_p_Ci = ( original_contact_params.point - original_com_positions[parent_link_indices] ) # Extract the shape types of the parent links. parent_link_shapes = jnp.array(link_shapes[parent_link_indices]) def sphere(parent_idx, L_p_C): r = scaling_factors.dims[parent_idx][0] return L_p_C * r def cylinder(parent_idx, L_p_C): # TODO: Cylinder collisions are not currently supported in JaxSim. 
return L_p_C def box(parent_idx, L_p_C): lx, ly, lz = scaling_factors.dims[parent_idx] return jnp.hstack( [ L_p_C[0] * lx, L_p_C[1] * ly, L_p_C[2] * lz, ] ) def mesh(parent_idx, L_p_C): sx, sy, sz = scaling_factors.dims[parent_idx] return jnp.hstack( [ L_p_C[0] * sx, L_p_C[1] * sy, L_p_C[2] * sz, ] ) new_positions = jax.vmap( lambda shape_idx, parent_idx, L_p_C: jax.lax.switch( shape_idx, (box, cylinder, sphere, mesh), parent_idx, L_p_C ) )( parent_link_shapes, parent_link_indices, L_p_Ci, ) return new_positions + updated_com_positions[parent_link_indices] @staticmethod def compute_inertia_link(I_com, L_H_G) -> jtp.Matrix: """ Compute the inertia tensor of the link based on its shape and mass. """ L_R_G = L_H_G[:3, :3] return L_R_G @ I_com @ L_R_G.T @staticmethod def apply_scaling( has_joints: bool, hw_metadata: HwLinkMetadata, scaling_factors: ScalingFactors, ) -> HwLinkMetadata: """ Apply scaling to the hardware parameters and return a new HwLinkMetadata object. Args: has_joints: A boolean indicating if the model has joints. hw_metadata: the original HwLinkMetadata object. scaling_factors: the scaling factors to apply. has_joints: whether the model has at least one joint. Returns: A new HwLinkMetadata object with updated parameters. 
""" scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( hw_metadata.link_shape, scaling_factors ) # ================================= # Update the kinematics of the link # ================================= # Get the nominal transforms L_H_G = hw_metadata.L_H_G L_H_vis = hw_metadata.L_H_vis L_H_pre_array = hw_metadata.L_H_pre L_H_pre_mask = hw_metadata.L_H_pre_mask # Express the transforms in the G frame G_H_L = jaxsim.math.Transform.inverse(L_H_G) G_H_vis = G_H_L @ L_H_vis G_H_pre_array = ( jax.vmap(lambda L_H_pre: G_H_L @ L_H_pre)(L_H_pre_array) if has_joints else L_H_pre_array ) # Apply the scaling to the position vectors G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3]) # Apply scaling to the position vectors in G_H_pre_array based on the mask G_H̅_pre_array = ( G_H_pre_array.at[:, :3, 3].set( jnp.where( L_H_pre_mask[:, None], scale_vector[None, :] * G_H_pre_array[:, :3, 3], G_H_pre_array[:, :3, 3], ) ) if has_joints else G_H_pre_array ) # Get back to the link frame L_H̅_G = L_H_G.at[:3, 3].set(scale_vector * L_H_G[:3, 3]) L_H̅_vis = L_H̅_G @ G_H̅_vis L_H̅_pre_array = ( jax.vmap(lambda G_H̅_pre: L_H̅_G @ G_H̅_pre)(G_H̅_pre_array) if has_joints else G_H̅_pre_array ) # =========================== # Update the shape parameters # =========================== updated_geoms = hw_metadata.geometry * scaling_factors.dims # ============================= # Scale the density of the link # ============================= updated_density = hw_metadata.density * scaling_factors.density # ============================= # Return updated HwLinkMetadata # ============================= return hw_metadata.replace( geometry=updated_geoms, density=updated_density, L_H_G=L_H̅_G, L_H_vis=L_H̅_vis, L_H_pre=L_H̅_pre_array, ) @jax_dataclasses.pytree_dataclass class ScalingFactors(JaxsimDataclass): """ Class storing scaling factors for hardware parameters. Attributes: dims: Scaling factors for shape dimensions. density: Scaling factor for density. 
""" dims: jtp.Vector density: jtp.Float @dataclasses.dataclass(frozen=True) class ConstraintType: """ Enumeration of all supported constraint types. """ Weld: ClassVar[int] = 0 # TODO: handle Connect constraint # Connect: ClassVar[int] = 1 @jax_dataclasses.pytree_dataclass class ConstraintMap(JaxsimDataclass): """ Class storing the kinematic constraints of a model. """ frame_idxs_1: jtp.Int = dataclasses.field( default_factory=lambda: jnp.array([], dtype=int) ) frame_idxs_2: jtp.Int = dataclasses.field( default_factory=lambda: jnp.array([], dtype=int) ) constraint_types: jtp.Int = dataclasses.field( default_factory=lambda: jnp.array([], dtype=int) ) K_P: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array([], dtype=float) ) K_D: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array([], dtype=float) ) # Precomputed parent link indices for each constraint pair parent_link_idxs_1: jtp.Int = dataclasses.field( default_factory=lambda: jnp.array([], dtype=int) ) parent_link_idxs_2: jtp.Int = dataclasses.field( default_factory=lambda: jnp.array([], dtype=int) ) def add_constraint( self, model: jaxsim.api.model.JaxSimModel, frame_idx_1: int, frame_idx_2: int, constraint_type: int, K_P: float | None = None, K_D: float | None = None, ) -> ConstraintMap: """ Add a constraint to the constraint map. Args: model: The model for which the constraints are added. frame_idx_1: The index of the first frame. frame_idx_2: The index of the second frame. constraint_type: The type of constraint. K_P: The proportional gain for Baumgarte stabilization (default: 1000). K_D: The derivative gain for Baumgarte stabilization (default: 2 * sqrt(K_P)). Returns: A new ConstraintMap instance with the added constraint. Note: Since this method returns a new instance of ConstraintMap with the new constraint, it will trigger recompilations in JIT-compiled functions. 
""" # Set default values for Baumgarte coefficients if not provided if K_P is None: K_P = jnp.array([1000.0]) if K_D is None: K_D = 2 * jnp.sqrt(K_P) # Create new arrays with the input elements appended new_frame_idxs_1 = jnp.append(self.frame_idxs_1, frame_idx_1) new_frame_idxs_2 = jnp.append(self.frame_idxs_2, frame_idx_2) new_constraint_types = jnp.append(self.constraint_types, constraint_type) new_K_P = jnp.append(self.K_P, K_P) new_K_D = jnp.append(self.K_D, K_D) # Compute parent link indices (now always available since model is required) parent_link_idx_1 = jaxsim.api.frame.idx_of_parent_link( model, frame_index=frame_idx_1 ) parent_link_idx_2 = jaxsim.api.frame.idx_of_parent_link( model, frame_index=frame_idx_2 ) new_parent_link_idxs_1 = jnp.append(self.parent_link_idxs_1, parent_link_idx_1) new_parent_link_idxs_2 = jnp.append(self.parent_link_idxs_2, parent_link_idx_2) # Return a new ConstraintMap object with updated attributes return ConstraintMap( frame_idxs_1=new_frame_idxs_1, frame_idxs_2=new_frame_idxs_2, constraint_types=new_constraint_types, K_P=new_K_P, K_D=new_K_D, parent_link_idxs_1=new_parent_link_idxs_1, parent_link_idxs_2=new_parent_link_idxs_2, ) ================================================ FILE: src/jaxsim/api/link.py ================================================ import functools from collections.abc import Sequence import jax import jax.numpy as jnp import jax.scipy.linalg import numpy as np import jaxsim.api as js import jaxsim.rbda import jaxsim.typing as jtp from jaxsim import exceptions from jaxsim.math import Adjoint from .common import VelRepr # ======================= # Index-related functions # ======================= @functools.partial(jax.jit, static_argnames="link_name") def name_to_idx(model: js.model.JaxSimModel, *, link_name: str) -> jtp.Int: """ Convert the name of a link to its index. Args: model: The model to consider. link_name: The name of the link. Returns: The index of the link. 
""" if link_name not in model.link_names(): raise ValueError(f"Link '{link_name}' not found in the model.") return ( jnp.array(model.kin_dyn_parameters.link_names.index(link_name)) .astype(int) .squeeze() ) def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str: """ Convert the index of a link to its name. Args: model: The model to consider. link_index: The index of the link. Returns: The name of the link. """ exceptions.raise_value_error_if( condition=link_index < 0, msg="Invalid link index '{idx}'", idx=link_index, ) return model.kin_dyn_parameters.link_names[link_index] @functools.partial(jax.jit, static_argnames="link_names") def names_to_idxs( model: js.model.JaxSimModel, *, link_names: Sequence[str] ) -> jax.Array: """ Convert a sequence of link names to their corresponding indices. Args: model: The model to consider. link_names: The names of the links. Returns: The indices of the links. """ return jnp.array( [name_to_idx(model=model, link_name=name) for name in link_names], ).astype(int) def idxs_to_names( model: js.model.JaxSimModel, *, link_indices: Sequence[jtp.IntLike] | jtp.VectorLike ) -> tuple[str, ...]: """ Convert a sequence of link indices to their corresponding names. Args: model: The model to consider. link_indices: The indices of the links. Returns: The names of the links. """ return tuple(np.array(model.kin_dyn_parameters.link_names)[list(link_indices)]) # ========= # Link APIs # ========= @jax.jit def mass(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float: """ Return the mass of the link. Args: model: The model to consider. link_index: The index of the link. Returns: The mass of the link. 
""" exceptions.raise_value_error_if( condition=jnp.array( [link_index < 0, link_index >= model.number_of_links()] ).any(), msg="Invalid link index '{idx}'", idx=link_index, ) return model.kin_dyn_parameters.link_parameters.mass[link_index].astype(float) @jax.jit def spatial_inertia( model: js.model.JaxSimModel, *, link_index: jtp.IntLike ) -> jtp.Matrix: r""" Compute the 6D spatial inertial of the link. Args: model: The model to consider. link_index: The index of the link. Returns: The :math:`6 \times 6` matrix representing the spatial inertia of the link expressed in the link frame (body-fixed representation). """ exceptions.raise_value_error_if( condition=jnp.array( [link_index < 0, link_index >= model.number_of_links()] ).any(), msg="Invalid link index '{idx}'", idx=link_index, ) link_parameters = jax.tree.map( lambda l: l[link_index], model.kin_dyn_parameters.link_parameters ) return js.kin_dyn_parameters.LinkParameters.spatial_inertia(link_parameters) @jax.jit def transform( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, ) -> jtp.Matrix: """ Compute the SE(3) transform from the world frame to the link frame. Args: model: The model to consider. data: The data of the considered model. link_index: The index of the link. Returns: The 4x4 matrix representing the transform. """ exceptions.raise_value_error_if( condition=jnp.array( [link_index < 0, link_index >= model.number_of_links()] ).any(), msg="Invalid link index '{idx}'", idx=link_index, ) return data._link_transforms[link_index] @jax.jit def com_position( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, in_link_frame: jtp.BoolLike = True, ) -> jtp.Vector: """ Compute the position of the center of mass of the link. Args: model: The model to consider. data: The data of the considered model. link_index: The index of the link. in_link_frame: Whether to return the position in the link frame or in the world frame. 
Returns: The 3D position of the center of mass of the link. """ from jaxsim.math.inertia import Inertia _, L_p_CoM, _ = Inertia.to_params( M=spatial_inertia(model=model, link_index=link_index) ) def com_in_link_frame(): return L_p_CoM.squeeze() def com_in_inertial_frame(): W_H_L = transform(link_index=link_index, model=model, data=data) W_p̃_CoM = W_H_L @ jnp.hstack([L_p_CoM.squeeze(), 1]) return W_p̃_CoM[0:3].squeeze() return jax.lax.select( pred=in_link_frame, on_true=com_in_link_frame(), on_false=com_in_inertial_frame(), ) @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: r""" Compute the free-floating jacobian of the link. Args: model: The model to consider. data: The data of the considered model. link_index: The index of the link. output_vel_repr: The output velocity representation of the free-floating jacobian. Returns: The :math:`6 \times (6+n)` free-floating jacobian of the link. Note: The input representation of the free-floating jacobian is the active velocity representation. """ exceptions.raise_value_error_if( condition=jnp.array( [link_index < 0, link_index >= model.number_of_links()] ).any(), msg="Invalid link index '{idx}'", idx=link_index, ) output_vel_repr = ( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) # Compute the doubly-left free-floating full jacobian. B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left( model=model, joint_positions=data.joint_positions, ) # Compute the actual doubly-left free-floating jacobian of the link. κb = model.kin_dyn_parameters.support_body_array_bool[link_index] B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WX_B # Adjust the input representation such that `J_WL_I @ I_ν`. 
match data.velocity_representation: case VelRepr.Inertial: W_H_B = data._base_transform B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841 B_X_W, jnp.eye(model.dofs()) ) case VelRepr.Body: B_J_WL_I = B_J_WL_B case VelRepr.Mixed: W_R_B = jaxsim.math.Quaternion.to_dcm(data.base_orientation) BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841 B_X_BW, jnp.eye(model.dofs()) ) case _: raise ValueError(data.velocity_representation) B_H_L = B_H_Li[link_index] # Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`. match output_vel_repr: case VelRepr.Inertial: W_H_B = data._base_transform W_X_B = Adjoint.from_transform(transform=W_H_B) O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I # noqa: F841 case VelRepr.Body: L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True) L_J_WL_I = L_X_B @ B_J_WL_I O_J_WL_I = L_J_WL_I case VelRepr.Mixed: W_H_B = data._base_transform W_H_L = W_H_B @ B_H_L LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L) LW_X_B = Adjoint.from_transform(transform=LW_H_B) LW_J_WL_I = LW_X_B @ B_J_WL_I O_J_WL_I = LW_J_WL_I case _: raise ValueError(output_vel_repr) return O_J_WL_I @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) def velocity( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, output_vel_repr: VelRepr | None = None, ) -> jtp.Vector: """ Compute the 6D velocity of the link. Args: model: The model to consider. data: The data of the considered model. link_index: The index of the link. output_vel_repr: The output velocity representation of the link velocity. Returns: The 6D velocity of the link in the specified velocity representation. 
""" exceptions.raise_value_error_if( condition=jnp.array( [link_index < 0, link_index >= model.number_of_links()] ).any(), msg="Invalid link index '{idx}'", idx=link_index, ) output_vel_repr = ( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) # Get the link jacobian having I as input representation (taken from data) # and O as output representation, specified by the user (or taken from data). O_J_WL_I = jacobian( model=model, data=data, link_index=link_index, output_vel_repr=output_vel_repr, ) # Get the generalized velocity in the input velocity representation. I_ν = data.generalized_velocity # Compute the link velocity in the output velocity representation. return O_J_WL_I @ I_ν @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) def jacobian_derivative( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: r""" Compute the derivative of the free-floating jacobian of the link. Args: model: The model to consider. data: The data of the considered model. link_index: The index of the link. output_vel_repr: The output velocity representation of the free-floating jacobian derivative. Returns: The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the link. Note: The input representation of the free-floating jacobian derivative is the active velocity representation. 
""" exceptions.raise_value_error_if( condition=jnp.array( [link_index < 0, link_index >= model.number_of_links()] ).any(), msg="Invalid link index '{idx}'", idx=link_index, ) output_vel_repr = ( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) O_J̇_WL_I = js.model.generalized_free_floating_jacobian_derivative( model=model, data=data, output_vel_repr=output_vel_repr )[link_index] return O_J̇_WL_I @jax.jit def bias_acceleration( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, ) -> jtp.Vector: """ Compute the bias acceleration of the link. Args: model: The model to consider. data: The data of the considered model. link_index: The index of the link. Returns: The 6D bias acceleration of the link. """ exceptions.raise_value_error_if( condition=jnp.array( [link_index < 0, link_index >= model.number_of_links()] ).any(), msg="Invalid link index '{idx}'", idx=link_index, ) # Compute the bias acceleration of all links in the active representation. 
O_v̇_WL = js.model.link_bias_accelerations(model=model, data=data)[link_index] return O_v̇_WL ================================================ FILE: src/jaxsim/api/model.py ================================================ from __future__ import annotations import copy import dataclasses import enum import functools import pathlib from collections.abc import Sequence import jax import jax.numpy as jnp import jax_dataclasses import numpy as np import rod from jax_dataclasses import Static from rod.urdf.exporter import UrdfExporter import jaxsim.api as js import jaxsim.terrain import jaxsim.typing as jtp from jaxsim import logging from jaxsim.api.kin_dyn_parameters import ( HwLinkMetadata, KinDynParameters, LinkParameters, LinkParametrizableShape, ScalingFactors, ) from jaxsim.math import Adjoint, Cross, Skew from jaxsim.parsers.descriptions import ModelDescription from jaxsim.parsers.descriptions.joint import JointDescription from jaxsim.parsers.descriptions.link import LinkDescription from jaxsim.parsers.rod.utils import prepare_mesh_for_parametrization from jaxsim.utils import JaxsimDataclass, Mutability, wrappers from .common import VelRepr class IntegratorType(enum.IntEnum): """The integrators available for the simulation.""" SemiImplicitEuler = enum.auto() RungeKutta4 = enum.auto() RungeKutta4Fast = enum.auto() @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) class JaxSimModel(JaxsimDataclass): """ The JaxSim model defining the kinematics and dynamics of a robot. 
""" model_name: Static[str] time_step: float = dataclasses.field( default=0.001, ) terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field( default_factory=jaxsim.terrain.FlatTerrain.build, repr=False ) gravity: Static[float] = -jaxsim.math.STANDARD_GRAVITY contact_model: Static[jaxsim.rbda.contacts.ContactModel | None] = dataclasses.field( default=None, repr=False ) contact_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field( default=None, repr=False ) actuation_params: Static[jaxsim.rbda.actuation.ActuationParams] = dataclasses.field( default=None, repr=False ) kin_dyn_parameters: js.kin_dyn_parameters.KinDynParameters | None = ( dataclasses.field(default=None, repr=False) ) integrator: Static[IntegratorType] = dataclasses.field( default=IntegratorType.SemiImplicitEuler, repr=False ) built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field( default=None, repr=False ) _description: Static[wrappers.HashlessObject[ModelDescription | None]] = ( dataclasses.field(default=None, repr=False) ) @property def description(self) -> ModelDescription: """ Return the model description. 
""" return self._description.get() def __eq__(self, other: JaxSimModel) -> bool: if not isinstance(other, JaxSimModel): return False if self.model_name != other.model_name: return False if self.time_step != other.time_step: return False if self.kin_dyn_parameters != other.kin_dyn_parameters: return False return True def __hash__(self) -> int: return hash( ( hash(self.model_name), hash(self.time_step), hash(self.kin_dyn_parameters), hash(self.contact_model), ) ) # ======================== # Initialization and state # ======================== @classmethod def build_from_model_description( cls, model_description: str | pathlib.Path | rod.Model, *, model_name: str | None = None, time_step: jtp.FloatLike | None = None, terrain: jaxsim.terrain.Terrain | None = None, contact_model: jaxsim.rbda.contacts.ContactModel | None = None, contact_params: jaxsim.rbda.contacts.ContactsParams | None = None, actuation_params: jaxsim.rbda.actuation.ActuationParams | None = None, integrator: IntegratorType | None = None, is_urdf: bool | None = None, considered_joints: Sequence[str] | None = None, gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, constraints: jaxsim.rbda.kinematic_constraints.ConstraintMap | None = None, parametrized_links: tuple[str, ...] | None = None, ) -> JaxSimModel: """ Build a Model object from a model description. Args: model_description: A path to an SDF/URDF file, a string containing its content, or a pre-parsed/pre-built rod model. model_name: The name of the model. If not specified, it is read from the description. time_step: The default time step to consider for the simulation. It can be manually overridden in the function that steps the simulation. terrain: The terrain to consider (the default is a flat infinite plane). contact_model: The contact model to consider. If not specified, a soft contacts model is used. contact_params: The parameters of the contact model. actuation_params: The parameters of the actuation model. 
integrator: The integrator to use for the simulation. is_urdf: The optional flag to force the model description to be parsed as a URDF. This is usually automatically inferred. considered_joints: The list of joints to consider. If None, all joints are considered. gravity: The gravity constant. Normally passed as a positive value. constraints: An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered. Note that constraints can be used only with RelaxedRigidContacts. parametrized_links: The optional list of links to be parametrized. If None, all links are parametrized. Returns: The built Model object. """ import jaxsim.parsers.rod # Parse the input resource (either a path to file or a string with the URDF/SDF) # and build the -intermediate- model description. intermediate_description = jaxsim.parsers.rod.build_model_description( model_description=model_description, is_urdf=is_urdf ) # Lump links together if not all joints are considered. # Note: this procedure assigns a zero position to all joints not considered. if considered_joints is not None: intermediate_description = intermediate_description.reduce( considered_joints=considered_joints ) # Build the model. model = cls.build( model_description=intermediate_description, model_name=model_name, time_step=time_step, terrain=terrain, contact_model=contact_model, actuation_params=actuation_params, contact_params=contact_params, integrator=integrator, gravity=-gravity, constraints=constraints, parametrized_links=parametrized_links, ) # Store the origin of the model, in case downstream logic needs it. with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): model.built_from = model_description # Compute the hw parametrization metadata of the model # TODO: move the building of the metadata to KinDynParameters.build() # and use the model_description instead of model.built_from. 
with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): model.kin_dyn_parameters.hw_link_metadata = model.compute_hw_link_metadata( parametrized_links=parametrized_links ) return model @classmethod def build( cls, model_description: ModelDescription, *, model_name: str | None = None, time_step: jtp.FloatLike | None = None, terrain: jaxsim.terrain.Terrain | None = None, contact_model: jaxsim.rbda.contacts.ContactModel | None = None, contact_params: jaxsim.rbda.contacts.ContactsParams | None = None, actuation_params: jaxsim.rbda.actuation.ActuationParams | None = None, integrator: IntegratorType | None = None, gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, constraints: jaxsim.rbda.kinematic_constraints.ConstraintMap | None = None, parametrized_links: tuple[str, ...] | None = None, ) -> JaxSimModel: """ Build a Model object from an intermediate model description. Args: model_description: The intermediate model description defining the kinematics and dynamics of the model. model_name: The name of the model. If not specified, it is read from the description. time_step: The default time step to consider for the simulation. It can be manually overridden in the function that steps the simulation. terrain: The terrain to consider (the default is a flat infinite plane). The optional name of the model overriding the physics model name. contact_model: The contact model to consider. If not specified, a soft contact model is used. contact_params: The parameters of the contact model. actuation_params: The parameters of the actuation model. integrator: The integrator to use for the simulation. gravity: The gravity constant. constraints: An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered. parametrized_links: The optional list of links to be parametrized. If None, all links are parametrized. Returns: The built Model object. 
""" # Set the model name (if not provided, use the one from the model description). model_name = model_name if model_name is not None else model_description.name # Consider the default terrain (a flat infinite plane) if not specified. terrain = ( terrain if terrain is not None else JaxSimModel.__dataclass_fields__["terrain"].default_factory() ) # Consider the default time step if not specified. time_step = ( time_step if time_step is not None else JaxSimModel.__dataclass_fields__["time_step"].default ) # Create the default contact model. # It will be populated with an initial estimation of good parameters. # While these might not be the best, they are a good starting point. contact_model = ( contact_model if contact_model is not None else jaxsim.rbda.contacts.SoftContacts.build() ) if contact_params is None: contact_params = contact_model._parameters_class() if actuation_params is None: actuation_params = jaxsim.rbda.actuation.ActuationParams() # Consider the default integrator if not specified. integrator = ( integrator if integrator is not None else JaxSimModel.__dataclass_fields__["integrator"].default ) # Build the model. model = cls( model_name=model_name, kin_dyn_parameters=js.kin_dyn_parameters.KinDynParameters.build( model_description=model_description, constraints=constraints ), time_step=time_step, terrain=terrain, contact_model=contact_model, contact_params=contact_params, actuation_params=actuation_params, integrator=integrator, gravity=gravity, # The following is wrapped as hashless since it's a static argument, and we # don't want to trigger recompilation if it changes. All relevant parameters # needed to compute kinematics and dynamics quantities are stored in the # kin_dyn_parameters attribute. _description=wrappers.HashlessObject(obj=model_description), ) return model def compute_hw_link_metadata( self, parametrized_links: tuple[str, ...] | None = None ) -> HwLinkMetadata: """ Compute the parametric metadata of the links in the model. 
Args: parametrized_links: An optional tuple of link names to be parametrized. If None, all links will be parametrized. Returns: An instance of HwLinkMetadata containing the metadata of all links. """ model_description = self.description # Get ordered links and joints from the model description ordered_links: list[LinkDescription] = sorted( list(model_description.links_dict.values()), key=lambda l: l.index, ) ordered_joints: list[JointDescription] = sorted( list(model_description.joints_dict.values()), key=lambda j: j.index, ) # Ensure the model was built from a valid source rod_model = None match self.built_from: case str() | pathlib.Path(): rod_model = rod.Sdf.load(sdf=self.built_from).models()[0] assert rod_model.name == self.name() case rod.Model(): rod_model = self.built_from case _: logging.debug( f"Invalid type for model.built_from ({type(self.built_from)})." "Skipping for hardware parametrization." ) return HwLinkMetadata.empty() # Use URDF frame convention for consistent pose representation rod_model.switch_frame_convention( frame_convention=rod.FrameConvention.Urdf, explicit_frames=True ) rod_links_dict = {link.name: link for link in rod_model.links()} # Initialize lists to collect metadata for all links shapes = [] geoms = [] densities = [] L_H_Gs = [] L_H_vises = [] L_H_pre_masks = [] L_H_pre = [] mesh_moments_list = [] mesh_vertices = [] mesh_faces = [] mesh_offsets = [] mesh_uris = [] # Process each link, only parametrizing those in parametrized_links if provided for link_description in ordered_links: link_name = link_description.name if parametrized_links is not None and link_name not in parametrized_links: # Mark as unsupported for non-parametrized links shapes.append(LinkParametrizableShape.Unsupported) geoms.append([0, 0, 0]) densities.append(0.0) L_H_Gs.append(jnp.eye(4)) L_H_vises.append(jnp.eye(4)) L_H_pre_masks.append([0] * self.number_of_joints()) L_H_pre.append([jnp.eye(4)] * self.number_of_joints()) mesh_vertices.append(None) 
mesh_faces.append(None) mesh_offsets.append(None) mesh_uris.append(None) mesh_moments_list.append(np.zeros(13)) continue rod_link = rod_links_dict.get(link_name) link_index = int(js.link.name_to_idx(model=self, link_name=link_name)) # Get child joints for the link child_joints_indices = [ js.joint.name_to_idx(model=self, joint_name=j.name) for j in ordered_joints if j.parent.name == link_name ] # Skip unsupported links if not jnp.allclose( self.kin_dyn_parameters.joint_model.suc_H_i[link_index], jnp.eye(4), **(dict(atol=1e-6) if not jax.config.jax_enable_x64 else {}), ): logging.debug( f"Skipping link '{link_name}' for hardware parametrization due to unsupported suc_H_link." ) rod_link = None # Compute density and dimensions mass = float(self.kin_dyn_parameters.link_parameters.mass[link_index]) # Find the first supported visual supported_visual = ( next( ( v for v in rod_link.visuals() if isinstance( v.geometry.geometry(), (rod.Box, rod.Sphere, rod.Cylinder, rod.Mesh), ) ), None, ) if rod_link else None ) geometry = ( supported_visual.geometry.geometry() if supported_visual else None ) if isinstance(geometry, rod.Box): lx, ly, lz = geometry.size density = mass / (lx * ly * lz) geom = [lx, ly, lz] shape = LinkParametrizableShape.Box mesh_vertices.append(None) mesh_faces.append(None) mesh_offsets.append(None) mesh_uris.append(None) mesh_moments_list.append(np.zeros(13)) elif isinstance(geometry, rod.Sphere): r = geometry.radius density = mass / (4 / 3 * jnp.pi * r**3) geom = [r, 0, 0] shape = LinkParametrizableShape.Sphere mesh_vertices.append(None) mesh_faces.append(None) mesh_offsets.append(None) mesh_uris.append(None) mesh_moments_list.append(np.zeros(13)) elif isinstance(geometry, rod.Cylinder): r, l = geometry.radius, geometry.length density = mass / (jnp.pi * r**2 * l) geom = [r, l, 0] shape = LinkParametrizableShape.Cylinder mesh_vertices.append(None) mesh_faces.append(None) mesh_offsets.append(None) mesh_uris.append(None) 
mesh_moments_list.append(np.zeros(13)) elif isinstance(geometry, rod.Mesh): # Load and prepare mesh for parametric scaling try: mesh_data = prepare_mesh_for_parametrization( mesh_uri=geometry.uri, scale=geometry.scale, ) density = ( mass / mesh_data["volume"] if mesh_data["volume"] > 0 else 0.0 ) # For meshes, store cumulative scale factors (initially 1.0) in geometry # instead of bounding box extents. This allows proper multiplicative scaling. geom = [1.0, 1.0, 1.0] shape = LinkParametrizableShape.Mesh # Store mesh data mesh_vertices.append(mesh_data["vertices"]) mesh_faces.append(mesh_data["faces"]) mesh_offsets.append(mesh_data["offset"]) mesh_uris.append(mesh_data["uri"]) # Precompute volumetric moments for JIT-friendly inertia computation mesh_moments_list.append( HwLinkMetadata.precompute_mesh_moments( mesh_data["vertices"], mesh_data["faces"] ) ) logging.info( f"Loaded mesh for link '{link_name}': " f"{len(mesh_data['vertices'])} vertices, " f"{len(mesh_data['faces'])} faces, " ) except Exception as e: logging.warning( f"Failed to load mesh for link '{link_name}': {e}. " f"Marking as unsupported." ) density = 0.0 geom = [0, 0, 0] shape = LinkParametrizableShape.Unsupported mesh_vertices.append(None) mesh_faces.append(None) mesh_offsets.append(None) mesh_uris.append(None) mesh_moments_list.append(np.zeros(13)) else: logging.debug( f"Skipping link '{link_name}' for hardware parametrization due to unsupported geometry." 
) density = 0.0 geom = [0, 0, 0] shape = LinkParametrizableShape.Unsupported mesh_vertices.append(None) mesh_faces.append(None) mesh_offsets.append(None) mesh_uris.append(None) mesh_moments_list.append(np.zeros(13)) inertial_pose = ( rod_link.inertial.pose.transform() if rod_link else jnp.eye(4) ) visual_pose = ( supported_visual.pose.transform() if supported_visual else jnp.eye(4) ) l_h_pre_mask = [ int(joint_index in child_joints_indices) for joint_index in range(self.number_of_joints()) ] l_h_pre = [ ( self.kin_dyn_parameters.joint_model.λ_H_pre[joint_index + 1] if joint_index in child_joints_indices else jnp.eye(4) ) for joint_index in range(self.number_of_joints()) ] shapes.append(shape) geoms.append(geom) densities.append(density) L_H_Gs.append(inertial_pose) L_H_vises.append(visual_pose) L_H_pre_masks.append(l_h_pre_mask) L_H_pre.append(l_h_pre) if np.all(np.array(shapes) == LinkParametrizableShape.Unsupported): logging.debug( "All links were skipped for hardware parametrization. Returning empty metadata." 
) return HwLinkMetadata.empty() # Stack collected data into JAX arrays # Handle L_H_pre specially: ensure shape (n_links, n_joints, 4, 4) even when n_joints=0 L_H_pre_array = jnp.array(L_H_pre, dtype=float) if self.number_of_joints() == 0: # Reshape from (n_links, 0) to (n_links, 0, 4, 4) n_links = len(L_H_pre) L_H_pre_array = L_H_pre_array.reshape(n_links, 0, 4, 4) return HwLinkMetadata( link_shape=jnp.array(shapes, dtype=int), geometry=jnp.array(geoms, dtype=float), density=jnp.array(densities, dtype=float), L_H_G=jnp.array(L_H_Gs, dtype=float), L_H_vis=jnp.array(L_H_vises, dtype=float), L_H_pre_mask=jnp.array(L_H_pre_masks, dtype=bool), L_H_pre=L_H_pre_array, mesh_moments=jnp.array(np.stack(mesh_moments_list), dtype=float), mesh_vertices=( tuple( wrappers.HashedNumpyArray(array=v) if v is not None else None for v in mesh_vertices ) if any(v is not None for v in mesh_vertices) else None ), mesh_faces=( tuple( wrappers.HashedNumpyArray(array=f) if f is not None else None for f in mesh_faces ) if any(f is not None for f in mesh_faces) else None ), mesh_offset=( tuple( wrappers.HashedNumpyArray(array=o) if o is not None else None for o in mesh_offsets ) if any(o is not None for o in mesh_offsets) else None ), mesh_uri=( tuple(mesh_uris) if any(u is not None for u in mesh_uris) else None ), ) def export_updated_model(self) -> str: """ Export the JaxSim model to URDF with the current hardware parameters. Returns: The URDF string of the updated model. Note: This method is not meant to be used in JIT-compiled functions. 
""" if isinstance(jnp.zeros(0), jax.core.Tracer): raise RuntimeError("This method cannot be used in JIT-compiled functions") # Ensure `built_from` is a ROD model and create `rod_model_output` if isinstance(self.built_from, rod.Model): rod_model_output = copy.deepcopy(self.built_from) elif isinstance(self.built_from, (str, pathlib.Path)): rod_model_output = rod.Sdf.load(sdf=self.built_from).models()[0] else: raise ValueError( "The JaxSim model must be built from a valid ROD model source" ) # Switch to URDF frame convention for easier mapping rod_model_output.switch_frame_convention( frame_convention=rod.FrameConvention.Urdf, explicit_frames=True, attach_frames_to_links=True, ) # Get links and joints from the ROD model links_dict = {link.name: link for link in rod_model_output.links()} joints_dict = {joint.name: joint for joint in rod_model_output.joints()} # Iterate over the hardware metadata to update the ROD model hw_metadata = self.kin_dyn_parameters.hw_link_metadata reduced_link_names = set(self.link_names()) reduced_joint_names = set(self.joint_names()) unit_scale = np.ones(3, dtype=float) link_scale_factors: dict[str, np.ndarray] = {} def collect_link_elements(link) -> list: elements_to_update_raw = (link.visual, link.collision) elements_to_update = [] for entry in elements_to_update_raw: if entry is None: continue if isinstance(entry, (list, tuple)): elements_to_update.extend(e for e in entry if e is not None) else: elements_to_update.append(entry) return elements_to_update def scale_pose_translation(element, scale_vector): if getattr(element, "pose", None) is None: return transform = np.array(element.pose.transform(), dtype=float) transform[0:3, 3] = scale_vector * transform[0:3, 3] element.pose = rod.Pose.from_transform( transform=transform, relative_to=element.pose.relative_to, ) def scale_link_elements( elements_to_update: list, scale_vector: np.ndarray, *, mesh_pose: rod.Pose | None = None, mesh_shape_link: bool = False, ) -> None: for element in 
elements_to_update: if ( element is None or not hasattr(element, "geometry") or element.geometry is None ): continue geometry = element.geometry if getattr(geometry, "box", None) is not None: current_size = np.array(geometry.box.size, dtype=float) geometry.box.size = tuple( float(v) for v in (current_size * scale_vector).tolist() ) scale_pose_translation(element, scale_vector) elif getattr(geometry, "sphere", None) is not None: geometry.sphere.radius = float( float(geometry.sphere.radius) * float(scale_vector[0]) ) scale_pose_translation(element, scale_vector) elif getattr(geometry, "cylinder", None) is not None: geometry.cylinder.radius = float( float(geometry.cylinder.radius) * float(scale_vector[0]) ) geometry.cylinder.length = float( float(geometry.cylinder.length) * float(scale_vector[2]) ) scale_pose_translation(element, scale_vector) elif getattr(geometry, "mesh", None) is not None: base_scale = ( np.array(geometry.mesh.scale, dtype=float) if geometry.mesh.scale is not None else unit_scale ) geometry.mesh.scale = tuple( float(v) for v in (base_scale * scale_vector).tolist() ) # Mesh-parametrized reduced links use metadata to preserve # the main visual placement in the exported URDF. 
if mesh_shape_link and mesh_pose is not None: element.pose = mesh_pose else: scale_pose_translation(element, scale_vector) for link_index, link_name in enumerate(self.link_names()): if link_name not in links_dict: continue # Skip links with unsupported shapes shape = hw_metadata.link_shape[link_index] if shape == LinkParametrizableShape.Unsupported: logging.debug(f"Skipping link '{link_name}' with unsupported shape") continue # Update mass and inertia mass = float(self.kin_dyn_parameters.link_parameters.mass[link_index]) center_of_mass = np.array( self.kin_dyn_parameters.link_parameters.center_of_mass[link_index] ) inertia_tensor = LinkParameters.unflatten_inertia_tensor( self.kin_dyn_parameters.link_parameters.inertia_elements[link_index] ) links_dict[link_name].inertial.mass = mass L_H_COM = np.eye(4) L_H_COM[0:3, 3] = center_of_mass links_dict[link_name].inertial.pose = rod.Pose.from_transform( transform=L_H_COM, relative_to=links_dict[link_name].inertial.pose.relative_to, ) links_dict[link_name].inertial.inertia = rod.Inertia.from_inertia_tensor( inertia_tensor=inertia_tensor, validate=True ) dims = np.array(hw_metadata.geometry[link_index], dtype=float) elements_to_update = collect_link_elements(links_dict[link_name]) def find_reference_geometry(attr: str, elements: list = elements_to_update): for element in elements: if ( element is None or not hasattr(element, "geometry") or element.geometry is None ): continue geometry = getattr(element.geometry, attr, None) if geometry is not None: return geometry return None if shape == LinkParametrizableShape.Mesh: scale_vector = dims elif shape == LinkParametrizableShape.Box: ref_box = find_reference_geometry("box") if ref_box is None: scale_vector = unit_scale else: base_size = np.array(ref_box.size, dtype=float) scale_vector = np.divide( dims, base_size, out=np.ones(3, dtype=float), where=np.abs(base_size) > 1e-12, ) elif shape == LinkParametrizableShape.Sphere: ref_sphere = find_reference_geometry("sphere") 
base_radius = ( float(ref_sphere.radius) if ref_sphere is not None else 1.0 ) s = float(dims[0]) / base_radius if abs(base_radius) > 1e-12 else 1.0 scale_vector = np.array([s, s, s], dtype=float) elif shape == LinkParametrizableShape.Cylinder: ref_cylinder = find_reference_geometry("cylinder") base_radius = ( float(ref_cylinder.radius) if ref_cylinder is not None else 1.0 ) base_length = ( float(ref_cylinder.length) if ref_cylinder is not None else 1.0 ) s_radius = ( float(dims[0]) / base_radius if abs(base_radius) > 1e-12 else 1.0 ) s_length = ( float(dims[1]) / base_length if abs(base_length) > 1e-12 else 1.0 ) scale_vector = np.array([s_radius, s_radius, s_length], dtype=float) else: scale_vector = unit_scale link_scale_factors[link_name] = np.array(scale_vector, dtype=float) element_pose = rod.Pose.from_transform( transform=np.array(hw_metadata.L_H_vis[link_index]), relative_to=link_name, ) scale_link_elements( elements_to_update=elements_to_update, scale_vector=scale_vector, mesh_pose=element_pose, mesh_shape_link=(shape == LinkParametrizableShape.Mesh), ) # Update joint poses for joint_index in range(self.number_of_joints()): if hw_metadata.L_H_pre_mask[link_index, joint_index]: joint_name = js.joint.idx_to_name( model=self, joint_index=joint_index ) if joint_name in joints_dict: joints_dict[joint_name].pose = rod.Pose.from_transform( transform=np.array( hw_metadata.L_H_pre[link_index, joint_index] ), relative_to=link_name, ) # Propagate link scaling to descendants connected through fixed joints. # These links are typically reduced away in the JaxSim model (e.g. feet # attached to ankles) but still exist in the exported URDF tree. 
updated = True while updated: updated = False for joint in joints_dict.values(): if joint.type != "fixed": continue parent_scale = link_scale_factors.get(joint.parent, None) if parent_scale is None or joint.child in link_scale_factors: continue link_scale_factors[joint.child] = np.array(parent_scale, dtype=float) updated = True # Scale fixed-joint offsets that are not part of the reduced joint set. for joint_name, joint in joints_dict.items(): if joint.type != "fixed" or joint_name in reduced_joint_names: continue parent_scale = link_scale_factors.get(joint.parent, unit_scale) if np.allclose(parent_scale, unit_scale): continue if joint.pose is None: continue transform = np.array(joint.pose.transform(), dtype=float) transform[0:3, 3] = parent_scale * transform[0:3, 3] joint.pose = rod.Pose.from_transform( transform=transform, relative_to=joint.pose.relative_to, ) # Apply inherited scaling to non-reduced links (typically descendants # connected via fixed joints). for link_name, scale_vector in link_scale_factors.items(): if link_name in reduced_link_names: continue if np.allclose(scale_vector, unit_scale): continue if link_name not in links_dict: continue scale_link_elements( elements_to_update=collect_link_elements(links_dict[link_name]), scale_vector=scale_vector, ) # Restore continuous joint types for joints with infinite limits # to ensure valid URDF export (continuous joints should not have limits). # Continuous joints are internally represented as revolute with infinite # limits, but must be exported as type="continuous" for valid URDF. 
for joint in joints_dict.values(): # Skip if not a revolute joint with axis and limits if not ( joint.type == "revolute" and joint.axis is not None and joint.axis.limit is not None ): continue lower, upper = joint.axis.limit.lower, joint.axis.limit.upper # Check if both limits are infinite (indicating original continuous joint) if not ( lower is not None and upper is not None and np.isinf(lower) and lower < 0 and np.isinf(upper) and upper > 0 ): continue # Restore as continuous joint joint.type = "continuous" # Create a new Limit object with only effort and velocity # (no position limits for continuous joints) joint.axis.limit = rod.Limit( effort=joint.axis.limit.effort, velocity=joint.axis.limit.velocity, lower=None, upper=None, ) # Export the URDF string urdf_string = UrdfExporter(pretty=True).to_urdf_string(sdf=rod_model_output) return urdf_string # ========== # Properties # ========== def name(self) -> str: """ Return the name of the model. Returns: The name of the model. """ return self.model_name def number_of_links(self) -> int: """ Return the number of links in the model. Returns: The number of links in the model. Note: The base link is included in the count and its index is always 0. """ return self.kin_dyn_parameters.number_of_links() def number_of_joints(self) -> int: """ Return the number of joints in the model. Returns: The number of joints in the model. """ return self.kin_dyn_parameters.number_of_joints() def number_of_frames(self) -> int: """ Return the number of frames in the model. Returns: The number of frames in the model. """ return self.kin_dyn_parameters.number_of_frames() # ================= # Base link methods # ================= def floating_base(self) -> bool: """ Return whether the model has a floating base. Returns: True if the model is floating-base, False otherwise. """ return self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6 def base_link(self) -> str: """ Return the name of the base link. Returns: The name of the base link. 
Note: By default, the base link is the root of the kinematic tree. """ return self.link_names()[0] # ===================== # Joint-related methods # ===================== def dofs(self) -> int: """ Return the number of degrees of freedom of the model. Returns: The number of degrees of freedom of the model. Note: We do not yet support multi-DoF joints, therefore this is always equal to the number of joints. In the future, this could be different. """ return sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:]) def joint_names(self) -> tuple[str, ...]: """ Return the names of the joints in the model. Returns: The names of the joints in the model. """ return self.kin_dyn_parameters.joint_model.joint_names[1:] # ==================== # Link-related methods # ==================== def link_names(self) -> tuple[str, ...]: """ Return the names of the links in the model. Returns: The names of the links in the model. """ return self.kin_dyn_parameters.link_names # ===================== # Frame-related methods # ===================== def frame_names(self) -> tuple[str, ...]: """ Return the names of the frames in the model. Returns: The names of the frames in the model. """ return self.kin_dyn_parameters.frame_parameters.name # ===================== # Model post-processing # ===================== def reduce( model: JaxSimModel, considered_joints: tuple[str, ...], locked_joint_positions: dict[str, jtp.FloatLike] | None = None, ) -> JaxSimModel: """ Reduce the model by lumping together the links connected by removed joints. Args: model: The model to reduce. considered_joints: The sequence of joints to consider. locked_joint_positions: A dictionary containing the positions of the joints to be considered in the reduction process. The removed joints in the reduced model will have their position locked to their value of this dictionary. If a joint is not part of the dictionary, its position is set to zero. 
""" locked_joint_positions = ( locked_joint_positions if locked_joint_positions is not None else {} ) # If locked joints are passed, make sure that they are valid. if not set(locked_joint_positions).issubset(model.joint_names()): new_joints = set(model.joint_names()) - set(locked_joint_positions) raise ValueError(f"Passed joints not existing in the model: {new_joints}") # Operate on a deep copy of the model description in order to prevent problems # when mutable attributes are updated. intermediate_description = copy.deepcopy(model.description) # Update the initial position of the joints. # This is necessary to compute the correct pose of the link pairs connected # to removed joints. for joint_name in set(model.joint_names()) - set(considered_joints): j = intermediate_description.joints_dict[joint_name] with j.mutable_context(): j.initial_position = locked_joint_positions.get(joint_name, 0.0) # Reduce the model description. # If `considered_joints` contains joints not existing in the model, # the method will raise an exception. reduced_intermediate_description = intermediate_description.reduce( considered_joints=list(considered_joints) ) # Build the reduced model. reduced_model = JaxSimModel.build( model_description=reduced_intermediate_description, model_name=model.name(), time_step=model.time_step, terrain=model.terrain, contact_model=model.contact_model, contact_params=model.contact_params, actuation_params=model.actuation_params, gravity=model.gravity, integrator=model.integrator, constraints=model.kin_dyn_parameters.constraints, ) with reduced_model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): # Store the origin of the model, in case downstream logic needs it. reduced_model.built_from = model.built_from # Compute the hw parametrization metadata of the reduced model # TODO: move the building of the metadata to KinDynParameters.build() # and use the model_description instead of model.built_from. 
reduced_model.kin_dyn_parameters.hw_link_metadata = ( reduced_model.compute_hw_link_metadata() ) return reduced_model # =================== # Inertial properties # =================== @jax.jit @js.common.named_scope def total_mass(model: JaxSimModel) -> jtp.Float: """ Compute the total mass of the model. Args: model: The model to consider. Returns: The total mass of the model. """ return model.kin_dyn_parameters.link_parameters.mass.sum().astype(float) @jax.jit @js.common.named_scope def link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array: """ Compute the spatial 6D inertia matrices of all links of the model. Args: model: The model to consider. Returns: A 3D array containing the stacked spatial 6D inertia matrices of the links. """ return jax.vmap(js.kin_dyn_parameters.LinkParameters.spatial_inertia)( model.kin_dyn_parameters.link_parameters ) # ============================== # Rigid Body Dynamics Algorithms # ============================== def _adjoint_from_rotation_translation( rotation: jtp.Matrix, translation: jtp.Vector, ) -> jtp.Matrix: zeros = jnp.zeros_like(rotation) top_right = jnp.einsum("...ij,...jk->...ik", Skew.wedge(translation), rotation) return jnp.concatenate( [ jnp.concatenate([rotation, top_right], axis=-1), jnp.concatenate([zeros, rotation], axis=-1), ], axis=-2, ) def _inverse_adjoint_from_rotation_translation( rotation: jtp.Matrix, translation: jtp.Vector, ) -> jtp.Matrix: rotation_t = jnp.swapaxes(rotation, -1, -2) zeros = jnp.zeros_like(rotation_t) top_right = -jnp.einsum("...ij,...jk->...ik", rotation_t, Skew.wedge(translation)) return jnp.concatenate( [ jnp.concatenate([rotation_t, top_right], axis=-1), jnp.concatenate([zeros, rotation_t], axis=-1), ], axis=-2, ) def _apply_input_representation_to_jacobian( jacobian: jtp.Matrix, base_transform: jtp.Matrix, ) -> jtp.Matrix: transformed_base = jnp.einsum( "...ij,jk->...ik", jacobian[..., :, 0:6], base_transform, ) return jnp.concatenate([transformed_base, jacobian[..., :, 6:]], 
axis=-1) def _apply_input_representation_derivative_to_jacobian( jacobian: jtp.Matrix, base_transform_derivative: jtp.Matrix, ) -> jtp.Matrix: transformed_base = jnp.einsum( "...ij,jk->...ik", jacobian[..., :, 0:6], base_transform_derivative, ) return jnp.concatenate( [transformed_base, jnp.zeros_like(jacobian[..., :, 6:])], axis=-1, ) def _link_jacobian_support_mask( model: JaxSimModel, *, dtype: jnp.dtype, ) -> jtp.Matrix: κb = model.kin_dyn_parameters.support_body_array_bool return jnp.concatenate( [ jnp.ones((model.number_of_links(), 5), dtype=dtype), jnp.asarray(κb, dtype=dtype), ], axis=1, ) def _body_input_transform( data: js.data.JaxSimModelData, ) -> tuple[jtp.Matrix, jtp.Matrix]: base_transform = data._base_transform base_rotation = base_transform[0:3, 0:3] match data.velocity_representation: case VelRepr.Inertial: B_X_I = _inverse_adjoint_from_rotation_translation( rotation=base_rotation, translation=base_transform[0:3, 3], ) B_Ẋ_I = -B_X_I @ Cross.vx(data.base_velocity) case VelRepr.Body: B_X_I = jnp.eye(6, dtype=base_transform.dtype) B_Ẋ_I = jnp.zeros((6, 6), dtype=base_transform.dtype) case VelRepr.Mixed: B_X_I = _inverse_adjoint_from_rotation_translation( rotation=base_rotation, translation=jnp.zeros(3, dtype=base_transform.dtype), ) BW_v_BW_B = data.base_velocity.at[0:3].set( jnp.zeros(3, dtype=base_transform.dtype) ) B_Ẋ_I = -B_X_I @ Cross.vx(BW_v_BW_B) case _: raise ValueError(data.velocity_representation) return B_X_I, B_Ẋ_I def _link_output_adjoint_from_body( data: js.data.JaxSimModelData, B_H_L: jtp.Matrix, *, output_vel_repr: VelRepr, ) -> jtp.Matrix: base_transform = data._base_transform base_rotation = base_transform[0:3, 0:3] B_R_L = B_H_L[..., 0:3, 0:3] B_p_L = B_H_L[..., 0:3, 3] match output_vel_repr: case VelRepr.Inertial: return _adjoint_from_rotation_translation( rotation=base_rotation, translation=base_transform[0:3, 3], ) case VelRepr.Body: return _inverse_adjoint_from_rotation_translation( rotation=B_R_L, translation=B_p_L, ) case 
VelRepr.Mixed: W_p_B_in_LW = -jnp.einsum("ij,...j->...i", base_rotation, B_p_L) W_R_B = jnp.broadcast_to(base_rotation, B_R_L.shape) return _adjoint_from_rotation_translation( rotation=W_R_B, translation=W_p_B_in_LW, ) case _: raise ValueError(output_vel_repr) @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) def generalized_free_floating_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, *, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: """ Compute the free-floating jacobians of all links. Args: model: The model to consider. data: The data of the considered model. output_vel_repr: The output velocity representation of the free-floating jacobians. Returns: The `(nL, 6, 6+dofs)` array containing the stacked free-floating jacobians of the links. The first axis is the link index. Note: The v-stacked version of the returned Jacobian array together with the flattened 6D forces of the links, are useful to compute the `J.T @ f` product of the multi-body EoM. """ output_vel_repr = ( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) # Compute the doubly-left free-floating full jacobian. B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left( model=model, joint_positions=data.joint_positions, ) support_mask = _link_jacobian_support_mask(model=model, dtype=B_J_full_WX_B.dtype) B_J_WL_B = support_mask[:, jnp.newaxis, :] * B_J_full_WX_B[jnp.newaxis, ...] 
B_X_I, _ = _body_input_transform(data=data) B_J_WL_I = _apply_input_representation_to_jacobian( jacobian=B_J_WL_B, base_transform=B_X_I, ) O_X_B = _link_output_adjoint_from_body( data=data, B_H_L=B_H_L, output_vel_repr=output_vel_repr, ) return jnp.einsum("...ij,...jk->...ik", O_X_B, B_J_WL_I) @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) def generalized_free_floating_jacobian_derivative( model: JaxSimModel, data: js.data.JaxSimModelData, *, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: """ Compute the free-floating jacobian derivatives of all links. Args: model: The model to consider. data: The data of the considered model. output_vel_repr: The output velocity representation of the free-floating jacobian derivatives. Returns: The `(nL, 6, 6+dofs)` array containing the stacked free-floating jacobian derivatives of the links. The first axis is the link index. """ output_vel_repr = ( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) # Compute the derivative of the doubly-left free-floating full jacobian. B_J̇_full_WX_B, B_H_L = jaxsim.rbda.jacobian_derivative_full_doubly_left( model=model, joint_positions=data.joint_positions, joint_velocities=data.joint_velocities, ) # The derivative of the equation to change the input and output representations # of the Jacobian derivative needs the computation of the plain link Jacobian. B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left( model=model, joint_positions=data.joint_positions, ) support_mask = _link_jacobian_support_mask(model=model, dtype=B_J̇_full_WX_B.dtype) B_J̇_WL_B = support_mask[:, jnp.newaxis, :] * B_J̇_full_WX_B[jnp.newaxis, ...] B_J_WL_B = support_mask[:, jnp.newaxis, :] * B_J_full_WL_B[jnp.newaxis, ...] 
B_X_I, B_Ẋ_I = _body_input_transform(data=data) B_J_WL_I = _apply_input_representation_to_jacobian( jacobian=B_J_WL_B, base_transform=B_X_I, ) B_J̇_WL_input = _apply_input_representation_to_jacobian( jacobian=B_J̇_WL_B, base_transform=B_X_I, ) B_J̇_WL_repr = _apply_input_representation_derivative_to_jacobian( jacobian=B_J_WL_B, base_transform_derivative=B_Ẋ_I, ) B_v_WB = B_X_I @ data.base_velocity B_ν = jnp.concatenate([B_v_WB, data.joint_velocities]) B_v_WL = jnp.einsum("bij,j->bi", B_J_WL_B, B_ν) O_X_B = _link_output_adjoint_from_body( data=data, B_H_L=B_H_L, output_vel_repr=output_vel_repr, ) match output_vel_repr: case VelRepr.Inertial: O_Ẋ_B = O_X_B @ Cross.vx(B_v_WB) case VelRepr.Body: B_v_B_L = B_v_WL - B_v_WB O_Ẋ_B = -jnp.einsum("...ij,...jk->...ik", O_X_B, Cross.vx(B_v_B_L)) case VelRepr.Mixed: base_rotation = data._base_transform[0:3, 0:3] B_p_L = B_H_L[..., 0:3, 3] W_p_B_in_LW = -jnp.einsum("ij,...j->...i", base_rotation, B_p_L) W_R_B = jnp.broadcast_to(base_rotation, B_H_L[..., 0:3, 0:3].shape) B_X_LW = _inverse_adjoint_from_rotation_translation( rotation=W_R_B, translation=W_p_B_in_LW, ) LW_v_WL = jnp.einsum("...ij,...j->...i", O_X_B, B_v_WL) LW_v_W_LW = LW_v_WL.at[..., 3:6].set(jnp.zeros_like(LW_v_WL[..., 3:6])) LW_v_LW_L = LW_v_WL - LW_v_W_LW LW_v_B_LW = LW_v_WL - jnp.einsum("...ij,j->...i", O_X_B, B_v_WB) - LW_v_LW_L O_Ẋ_B = -jnp.einsum( "...ij,...jk->...ik", O_X_B, Cross.vx(jnp.einsum("...ij,...j->...i", B_X_LW, LW_v_B_LW)), ) case _: raise ValueError(output_vel_repr) O_J̇_WL_I = jnp.einsum("...ij,...jk->...ik", O_Ẋ_B, B_J_WL_I) O_J̇_WL_I += jnp.einsum("...ij,...jk->...ik", O_X_B, B_J̇_WL_input) O_J̇_WL_I += jnp.einsum("...ij,...jk->...ik", O_X_B, B_J̇_WL_repr) return O_J̇_WL_I @functools.partial(jax.jit, static_argnames=["prefer_aba"]) def forward_dynamics( model: JaxSimModel, data: js.data.JaxSimModelData, *, joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, prefer_aba: float = True, ) -> tuple[jtp.Vector, 
jtp.Vector]: """ Compute the forward dynamics of the model. Args: model: The model to consider. data: The data of the considered model. joint_forces: The joint forces to consider as a vector of shape `(dofs,)`. link_forces: The link 6D forces consider as a matrix of shape `(nL, 6)`. The frame in which they are expressed must be `data.velocity_representation`. prefer_aba: Whether to prefer the ABA algorithm over the CRB one. Returns: A tuple containing the 6D acceleration in the active representation of the base link and the joint accelerations resulting from the application of the considered joint forces and external forces. """ forward_dynamics_fn = forward_dynamics_aba if prefer_aba else forward_dynamics_crb return forward_dynamics_fn( model=model, data=data, joint_forces=joint_forces, link_forces=link_forces, ) @functools.partial(jax.jit, static_argnames=("parallel",)) @js.common.named_scope def forward_dynamics_aba( model: JaxSimModel, data: js.data.JaxSimModelData, *, joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, parallel: bool = False, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute the forward dynamics of the model with the ABA algorithm. Args: model: The model to consider. data: The data of the considered model. joint_forces: The joint forces to consider as a vector of shape `(dofs,)`. link_forces: The link 6D forces to consider as a matrix of shape `(nL, 6)`. The frame in which they are expressed must be `data.velocity_representation`. parallel: If ``True``, use the level-parallel ABA implementation that processes independent tree branches simultaneously. Beneficial on GPU or for wide/deep kinematic trees. Returns: A tuple containing the 6D acceleration in the active representation of the base link and the joint accelerations resulting from the application of the considered joint forces and external forces. """ # ============ # Prepare data # ============ # Build joint forces, if not provided. 
τ = ( jnp.atleast_1d(joint_forces.squeeze()) if joint_forces is not None else jnp.zeros_like(data.joint_positions) ) # Build link forces, if not provided. f_L = ( jnp.atleast_2d(link_forces.squeeze()) if link_forces is not None else jnp.zeros((model.number_of_links(), 6)) ) # Create a references object that simplifies converting among representations. references = js.references.JaxSimModelReferences.build( model=model, joint_force_references=τ, link_forces=f_L, data=data, velocity_representation=data.velocity_representation, ) # Extract the state in inertial-fixed representation. with data.switch_velocity_representation(VelRepr.Inertial): W_p_B = data.base_position W_v_WB = data.base_velocity W_Q_B = data.base_orientation s = data.joint_positions ṡ = data.joint_velocities # Extract the inputs in inertial-fixed representation. W_f_L = references._link_forces τ = references._joint_force_references # ======================== # Compute forward dynamics # ======================== aba_fn = jaxsim.rbda.aba_parallel if parallel else jaxsim.rbda.aba W_v̇_WB, s̈ = aba_fn( model=model, base_position=W_p_B, base_quaternion=W_Q_B, joint_positions=s, base_linear_velocity=W_v_WB[0:3], base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, joint_transforms=model.kin_dyn_parameters.joint_transforms( joint_positions=s, base_transform=data.base_transform, ), joint_forces=τ, link_forces=W_f_L, standard_gravity=model.gravity, ) # ============= # Adjust output # ============= def to_active( W_v̇_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WB: jtp.Vector, W_v_WC: jtp.Vector ) -> jtp.Vector: """ Convert the inertial-fixed apparent base acceleration W_v̇_WB to another representation C_v̇_WB expressed in a generic frame C. """ # In Mixed representation, we need to include a cross product in ℝ⁶. # In Inertial and Body representations, the cross product is always zero. 
C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB) match data.velocity_representation: case VelRepr.Inertial: # In this case C=W W_H_C = W_H_W = jnp.eye(4) # noqa: F841 W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 case VelRepr.Body: # In this case C=B W_H_C = W_H_B = data._base_transform W_v_WC = W_v_WB case VelRepr.Mixed: # In this case C=B[W] W_H_B = data._base_transform W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 W_ṗ_B = data.base_velocity[0:3] W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 case _: raise ValueError(data.velocity_representation) # We need to convert the derivative of the base velocity to the active # representation. In Mixed representation, this conversion is not a plain # transformation with just X, but it also involves a cross product in ℝ⁶. C_v̇_WB = to_active( W_v̇_WB=W_v̇_WB, W_H_C=W_H_C, W_v_WB=W_v_WB, W_v_WC=W_v_WC, ) # The ABA algorithm already returns a zero base 6D acceleration for # fixed-based models. However, the to_active function introduces an # additional acceleration component in Mixed representation. # Here below we make sure that the base acceleration is zero. C_v̇_WB = C_v̇_WB if model.floating_base() else jnp.zeros(6) return C_v̇_WB.astype(float), s̈.astype(float) @jax.jit @js.common.named_scope def forward_dynamics_crb( model: JaxSimModel, data: js.data.JaxSimModelData, *, joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute the forward dynamics of the model with the CRB algorithm. Args: model: The model to consider. data: The data of the considered model. joint_forces: The joint forces to consider as a vector of shape `(dofs,)`. link_forces: The link 6D forces to consider as a matrix of shape `(nL, 6)`. The frame in which they are expressed must be `data.velocity_representation`. 
Returns: A tuple containing the 6D acceleration in the active representation of the base link and the joint accelerations resulting from the application of the considered joint forces and external forces. Note: Compared to ABA, this method could be significantly slower, especially for models with a large number of degrees of freedom. """ # ============ # Prepare data # ============ # Build joint torques if not provided. τ = ( jnp.atleast_1d(joint_forces) if joint_forces is not None else jnp.zeros_like(data.joint_positions) ) # Build external forces if not provided. f = ( jnp.atleast_2d(link_forces) if link_forces is not None else jnp.zeros(shape=(model.number_of_links(), 6)) ) # Compute terms of the floating-base EoM. M = free_floating_mass_matrix(model=model, data=data) h = free_floating_bias_forces(model=model, data=data) S = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T J = generalized_free_floating_jacobian(model=model, data=data) # TODO: invert the Mss block exploiting sparsity defined by the parent array λ(i) # ======================== # Compute forward dynamics # ======================== if model.floating_base(): # l: number of links. # g: generalized coordinates, 6 + number of joints. JTf = jnp.einsum("l6g,l6->g", J, f) ν̇ = jnp.linalg.solve(M, S @ τ - h + JTf) else: # l: number of links. # j: number of joints. JTf = jnp.einsum("l6j,l6->j", J[:, :, 6:], f) s̈ = jnp.linalg.solve(M[6:, 6:], τ - h[6:] + JTf) v̇_WB = jnp.zeros(6) ν̇ = jnp.hstack([v̇_WB, s̈.squeeze()]) # ============= # Adjust output # ============= # Extract the base acceleration in the active representation. # Note that this is an apparent acceleration (relevant in Mixed representation), # therefore it cannot be always expressed in different frames with just a # 6D transformation X. v̇_WB = ν̇[0:6].squeeze().astype(float) # Extract the joint accelerations. 
s̈ = jnp.atleast_1d(ν̇[6:].squeeze()).astype(float) return v̇_WB, s̈ @functools.partial(jax.jit, static_argnames=("parallel",)) @js.common.named_scope def forward_kinematics( model: JaxSimModel, data: js.data.JaxSimModelData, *, parallel: bool = False, ) -> jtp.Matrix: """ Compute the forward kinematics of the model. Args: model: The model to consider. data: The data of the considered model. parallel: If True, use the level-parallel FK implementation that processes independent tree branches simultaneously. Returns: The nL x 4 x 4 array containing the stacked homogeneous transformations of the links. The first axis is the link index. """ fk_fn = ( jaxsim.rbda.forward_kinematics_model_parallel if parallel else jaxsim.rbda.forward_kinematics_model ) # Recompute joint transforms from the model to ensure gradients # flow through model parameters. joint_transforms = model.kin_dyn_parameters.joint_transforms( joint_positions=data.joint_positions, base_transform=data.base_transform, ) W_H_LL, _ = fk_fn( model=model, base_position=data.base_position, base_quaternion=data.base_quaternion, joint_positions=data.joint_positions, joint_velocities=data.joint_velocities, base_linear_velocity_inertial=data._base_linear_velocity, base_angular_velocity_inertial=data._base_angular_velocity, joint_transforms=joint_transforms, ) return W_H_LL def _transform_M_block(M_body: jtp.Matrix, X: jtp.Matrix) -> jtp.Matrix: """ Apply invTᵀ M_body invT with invT = diag(X, I_n), without forming invT. Args: M_body: (6+n, 6+n) mass matrix (inverse) in body representation. X: (6, 6) adjoint (e.g. B_X_W or B_X_BW). Returns: M_repr: (6+n, 6+n) mass matrix (inverse) in the new representation. 
""" # invTᵀ M invT with invT = diag(X, I): # Mbb' = Xᵀ Mbb X # Mbj' = Xᵀ Mbj # Mjb' = Mjb X # Mjj' = Mjj Mbb_t = X.T @ M_body[:6, :6] @ X Mbj_t = X.T @ M_body[:6, 6:] Mjb_t = M_body[6:, :6] @ X Mjj_t = M_body[6:, 6:] top = jnp.concatenate([Mbb_t, Mbj_t], axis=1) bottom = jnp.concatenate([Mjb_t, Mjj_t], axis=1) return jnp.concatenate([top, bottom], axis=0) @jax.jit @js.common.named_scope def free_floating_mass_matrix( model: JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: """ Compute the free-floating mass matrix of the model with the CRBA algorithm. Args: model: The model to consider. data: The data of the considered model. Returns: The free-floating mass matrix of the model. """ M_body = jaxsim.rbda.crba( model=model, joint_positions=data.joint_positions, ) match data.velocity_representation: case VelRepr.Body: return M_body case VelRepr.Inertial: B_X_W = Adjoint.from_transform(transform=data.base_transform, inverse=True) return _transform_M_block(M_body, B_X_W) case VelRepr.Mixed: BW_H_B = data.base_transform.at[0:3, 3].set(jnp.zeros(3)) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) return _transform_M_block(M_body, B_X_BW) case _: raise ValueError(data.velocity_representation) @jax.jit @js.common.named_scope def free_floating_mass_matrix_inverse( model: JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: """ Compute the inverse of the free-floating mass matrix of the model with the CRBA algorithm. Args: model: The model to consider. data: The data of the considered model. Returns: The inverse of the free-floating mass matrix of the model. 
""" M_inv_body = jaxsim.rbda.mass_inverse( model=model, joint_transforms=data._joint_transforms, ) match data.velocity_representation: case VelRepr.Body: return M_inv_body case VelRepr.Inertial: W_X_B = Adjoint.from_transform(transform=data.base_transform) return _transform_M_block(M_inv_body, W_X_B.T) case VelRepr.Mixed: B_H_BW = data.base_transform.at[0:3, 3].set(jnp.zeros(3)) BW_X_B = Adjoint.from_transform(transform=B_H_BW) return _transform_M_block(M_inv_body, BW_X_B.T) case _: raise ValueError(data.velocity_representation) @jax.jit @js.common.named_scope def free_floating_coriolis_matrix( model: JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: """ Compute the free-floating Coriolis matrix of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The free-floating Coriolis matrix of the model. Note: This function, contrarily to other quantities of the equations of motion, does not exploit any iterative algorithm. Therefore, the computation of the Coriolis matrix may be much slower than other quantities. """ # We perform all the calculation in body-fixed. # The Coriolis matrix computed in this representation is converted later # to the active representation stored in data. with data.switch_velocity_representation(VelRepr.Body): B_ν = data.generalized_velocity # Doubly-left free-floating Jacobian. L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data) # Doubly-left free-floating Jacobian derivative. L_J̇_WL_B = generalized_free_floating_jacobian_derivative( model=model, data=data ) L_M_L = link_spatial_inertia_matrices(model=model) # Body-fixed link velocities. # Note: we could have called link.velocity() instead of computing it ourselves, # but since we need the link Jacobians later, we can save a double calculation. L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B) # Compute the contribution of each link to the Coriolis matrix. 
def compute_link_contribution(M, v, J, J̇) -> jtp.Array: return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇) C_B_links = jax.vmap(compute_link_contribution)( L_M_L, L_v_WL, L_J_WL_B, L_J̇_WL_B, ) # We need to adjust the Coriolis matrix for fixed-base models. # In this case, the base link does not contribute to the matrix, and we need to zero # the off-diagonal terms mapping joint quantities onto the base configuration. if model.floating_base(): C_B = C_B_links.sum(axis=0) else: C_B = C_B_links[1:].sum(axis=0) C_B = C_B.at[0:6, 6:].set(0.0) C_B = C_B.at[6:, 0:6].set(0.0) # Adjust the representation of the Coriolis matrix. # Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6. match data.velocity_representation: case VelRepr.Body: return C_B case VelRepr.Inertial: n = model.dofs() W_H_B = data._base_transform B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True) B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n)) with data.switch_velocity_representation(VelRepr.Inertial): W_v_WB = data.base_velocity B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n))) with data.switch_velocity_representation(VelRepr.Body): M = free_floating_mass_matrix(model=model, data=data) C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W) return C case VelRepr.Mixed: n = model.dofs() BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3)) B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n)) with data.switch_velocity_representation(VelRepr.Mixed): BW_v_WB = data.base_velocity BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) BW_v_BW_B = BW_v_WB - BW_v_W_BW B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B) B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n))) with data.switch_velocity_representation(VelRepr.Body): M = free_floating_mass_matrix(model=model, data=data) C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ 
B_T_BW) return C case _: raise ValueError(data.velocity_representation) @jax.jit @js.common.named_scope def inverse_dynamics( model: JaxSimModel, data: js.data.JaxSimModelData, *, joint_accelerations: jtp.VectorLike | None = None, base_acceleration: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute inverse dynamics with the RNEA algorithm. Args: model: The model to consider. data: The data of the considered model. joint_accelerations: The joint accelerations to consider as a vector of shape `(dofs,)`. base_acceleration: The base acceleration to consider as a vector of shape `(6,)`. link_forces: The link 6D forces to consider as a matrix of shape `(nL, 6)`. The frame in which they are expressed must be `data.velocity_representation`. Returns: A tuple containing the 6D force in the active representation applied to the base to obtain the considered base acceleration, and the joint forces to apply to obtain the considered joint accelerations. """ # ============ # Prepare data # ============ # Build joint accelerations, if not provided. s̈ = ( jnp.atleast_1d(jnp.array(joint_accelerations).squeeze()) if joint_accelerations is not None else jnp.zeros_like(data.joint_positions) ) # Build base acceleration, if not provided. v̇_WB = ( jnp.array(base_acceleration).squeeze() if base_acceleration is not None else jnp.zeros(6) ) # Build link forces, if not provided. f_L = ( jnp.atleast_2d(jnp.array(link_forces).squeeze()) if link_forces is not None else jnp.zeros(shape=(model.number_of_links(), 6)) ) def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): """ Convert the active representation of the base acceleration C_v̇_WB expressed in a generic frame C to the inertial-fixed representation W_v̇_WB. """ W_X_C = Adjoint.from_transform(transform=W_H_C) C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) C_v_WC = C_X_W @ W_v_WC # In Mixed representation, we need to include a cross product in ℝ⁶. 
# In Inertial and Body representations, the cross product is always zero. return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB) match data.velocity_representation: case VelRepr.Inertial: W_H_C = W_H_W = jnp.eye(4) # noqa: F841 W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 case VelRepr.Body: W_H_C = W_H_B = data._base_transform with data.switch_velocity_representation(VelRepr.Inertial): W_v_WC = W_v_WB = data.base_velocity case VelRepr.Mixed: W_H_B = data._base_transform W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 W_ṗ_B = data.base_velocity[0:3] W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 case _: raise ValueError(data.velocity_representation) # We need to convert the derivative of the base acceleration to the Inertial # representation. In Mixed representation, this conversion is not a plain # transformation with just X, but it also involves a cross product in ℝ⁶. W_v̇_WB = to_inertial( C_v̇_WB=v̇_WB, W_H_C=W_H_C, C_v_WB=data.base_velocity, W_v_WC=W_v_WC, ) # Create a references object that simplifies converting among representations. references = js.references.JaxSimModelReferences.build( model=model, data=data, link_forces=f_L, velocity_representation=data.velocity_representation, ) # Extract the state in inertial-fixed representation. with data.switch_velocity_representation(VelRepr.Inertial): W_p_B = data.base_position W_v_WB = data.base_velocity W_Q_B = data.base_quaternion s = data.joint_positions ṡ = data.joint_velocities # Extract the inputs in inertial-fixed representation. 
W_f_L = references._link_forces # ======================== # Compute inverse dynamics # ======================== W_f_B, τ = jaxsim.rbda.rnea( model=model, base_position=W_p_B, base_quaternion=W_Q_B, joint_positions=s, base_linear_velocity=W_v_WB[0:3], base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, base_linear_acceleration=W_v̇_WB[0:3], base_angular_acceleration=W_v̇_WB[3:6], joint_accelerations=s̈, joint_transforms=data._joint_transforms, link_forces=W_f_L, standard_gravity=model.gravity, ) # ============= # Adjust output # ============= # Express W_f_B in the active representation. f_B = js.data.JaxSimModelData.inertial_to_other_representation( array=W_f_B, other_representation=data.velocity_representation, transform=data._base_transform, is_force=True, ).squeeze() return f_B.astype(float), τ.astype(float) @jax.jit @js.common.named_scope def free_floating_gravity_forces( model: JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: r""" Compute the free-floating gravity forces :math:`g(\mathbf{q})` of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The free-floating gravity forces of the model. """ # Build a new state with zeroed velocities. data_rnea = js.data.JaxSimModelData.build( model=model, velocity_representation=data.velocity_representation, base_position=data.base_position, base_quaternion=data.base_quaternion, joint_positions=data.joint_positions, ) return jnp.hstack( inverse_dynamics( model=model, data=data_rnea, # Set zero inputs: joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())), base_acceleration=jnp.zeros(6), link_forces=jnp.zeros(shape=(model.number_of_links(), 6)), ) ).astype(float) @jax.jit @js.common.named_scope def free_floating_bias_forces( model: JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Vector: r""" Compute the free-floating bias forces :math:`h(\mathbf{q}, \boldsymbol{\nu})` of the model. Args: model: The model to consider. 
data: The data of the considered model. Returns: The free-floating bias forces of the model. """ # Set the generalized position and generalized velocity. base_linear_velocity, base_angular_velocity = None, None if model.floating_base(): base_velocity = data.base_velocity base_linear_velocity = base_velocity[:3] base_angular_velocity = base_velocity[3:] data_rnea = js.data.JaxSimModelData.build( model=model, velocity_representation=data.velocity_representation, base_position=data.base_position, base_quaternion=data.base_quaternion, joint_positions=data.joint_positions, joint_velocities=data.joint_velocities, base_linear_velocity=base_linear_velocity, base_angular_velocity=base_angular_velocity, ) return jnp.hstack( inverse_dynamics( model=model, data=data_rnea, # Set zero inputs: joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())), base_acceleration=jnp.zeros(6), link_forces=jnp.zeros(shape=(model.number_of_links(), 6)), ) ).astype(float) # ========================== # Other kinematic quantities # ========================== @jax.jit @js.common.named_scope def locked_spatial_inertia( model: JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: """ Compute the locked 6D inertia matrix of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The locked 6D inertia matrix of the model. """ return total_momentum_jacobian(model=model, data=data)[:, 0:6] @jax.jit @js.common.named_scope def total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector: """ Compute the total momentum of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The total momentum of the model in the active velocity representation. 
""" ν = data.generalized_velocity Jh = total_momentum_jacobian(model=model, data=data) return Jh @ ν @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) def total_momentum_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, *, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: """ Compute the jacobian of the total momentum. Args: model: The model to consider. data: The data of the considered model. output_vel_repr: The output velocity representation of the jacobian. Returns: The jacobian of the total momentum of the model in the active representation. """ output_vel_repr = ( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) if output_vel_repr is data.velocity_representation: return free_floating_mass_matrix(model=model, data=data)[0:6] with data.switch_velocity_representation(VelRepr.Body): B_Jh_B = free_floating_mass_matrix(model=model, data=data)[0:6] match data.velocity_representation: case VelRepr.Body: B_Jh = B_Jh_B case VelRepr.Inertial: B_X_W = Adjoint.from_transform(transform=data._base_transform, inverse=True) B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) case VelRepr.Mixed: BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3)) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) case _: raise ValueError(data.velocity_representation) match output_vel_repr: case VelRepr.Body: return B_Jh case VelRepr.Inertial: W_H_B = data._base_transform B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True) W_Xf_B = B_Xv_W.T W_Jh = W_Xf_B @ B_Jh return W_Jh case VelRepr.Mixed: BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3)) B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) BW_Xf_B = B_Xv_BW.T BW_Jh = BW_Xf_B @ B_Jh return BW_Jh case _: raise ValueError(output_vel_repr) @jax.jit @js.common.named_scope def average_velocity(model: JaxSimModel, data: 
js.data.JaxSimModelData) -> jtp.Vector: """ Compute the average velocity of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The average velocity of the model computed in the base frame and expressed in the active representation. """ ν = data.generalized_velocity J = average_velocity_jacobian(model=model, data=data) return J @ ν @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) def average_velocity_jacobian( model: JaxSimModel, data: js.data.JaxSimModelData, *, output_vel_repr: VelRepr | None = None, ) -> jtp.Matrix: """ Compute the Jacobian of the average velocity of the model. Args: model: The model to consider. data: The data of the considered model. output_vel_repr: The output velocity representation of the jacobian. Returns: The Jacobian of the average centroidal velocity of the model in the desired representation. """ output_vel_repr = ( output_vel_repr if output_vel_repr is not None else data.velocity_representation ) # Depending on the velocity representation, the frame G is either G[W] or G[B]. 
G_J = js.com.average_centroidal_velocity_jacobian(model=model, data=data) match output_vel_repr: case VelRepr.Inertial: GW_J = G_J W_p_CoM = js.com.com_position(model=model, data=data) W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) W_X_GW = Adjoint.from_transform(transform=W_H_GW) return W_X_GW @ GW_J case VelRepr.Body: GB_J = G_J W_p_B = data.base_position W_p_CoM = js.com.com_position(model=model, data=data) B_R_W = jaxsim.math.Quaternion.to_dcm(data.base_orientation).transpose() B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B)) B_X_GB = Adjoint.from_transform(transform=B_H_GB) return B_X_GB @ GB_J case VelRepr.Mixed: GW_J = G_J W_p_B = data.base_position W_p_CoM = js.com.com_position(model=model, data=data) BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B) BW_X_GW = Adjoint.from_transform(transform=BW_H_GW) return BW_X_GW @ GW_J # ======================== # Other dynamic quantities # ======================== @jax.jit @js.common.named_scope def link_bias_accelerations( model: JaxSimModel, data: js.data.JaxSimModelData, ) -> jtp.Vector: r""" Compute the bias accelerations of the links of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The bias accelerations of the links of the model. Note: This function computes the component of the total 6D acceleration not due to the joint or base acceleration. It is often called :math:`\dot{J} \boldsymbol{\nu}`. """ # ================================================ # Compute the body-fixed zero base 6D acceleration # ================================================ # Compute the base transform. W_H_B = data._base_transform def other_representation_to_inertial( C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector ) -> jtp.Vector: """ Convert the active representation of the base acceleration C_v̇_WB expressed in a generic frame C to the inertial-fixed representation W_v̇_WB. 
""" W_X_C = Adjoint.from_transform(transform=W_H_C) C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True) # In Mixed representation, we need to include a cross product in ℝ⁶. # In Inertial and Body representations, the cross product is always zero. return W_X_C @ (C_v̇_WB + jaxsim.math.Cross.vx(C_X_W @ W_v_WC) @ C_v_WB) # Here we initialize a zero 6D acceleration in the active representation, and # convert it to inertial-fixed. This is a useful intermediate representation # because the apparent acceleration W_v̇_WB is equal to the intrinsic acceleration # W_a_WB, and intrinsic accelerations can be expressed in different frames through # a simple C_X_W 6D transform. match data.velocity_representation: case VelRepr.Inertial: W_H_C = W_H_W = jnp.eye(4) # noqa: F841 W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 with data.switch_velocity_representation(VelRepr.Inertial): C_v_WB = W_v_WB = data.base_velocity case VelRepr.Body: W_H_C = W_H_B with data.switch_velocity_representation(VelRepr.Inertial): W_v_WC = W_v_WB = data.base_velocity # noqa: F841 with data.switch_velocity_representation(VelRepr.Body): C_v_WB = B_v_WB = data.base_velocity case VelRepr.Mixed: W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) W_H_C = W_H_BW with data.switch_velocity_representation(VelRepr.Mixed): W_ṗ_B = data.base_velocity[0:3] BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW) W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW # noqa: F841 with data.switch_velocity_representation(VelRepr.Mixed): C_v_WB = BW_v_WB = data.base_velocity # noqa: F841 case _: raise ValueError(data.velocity_representation) # Convert a zero 6D acceleration from the active representation to inertial-fixed. W_v̇_WB = other_representation_to_inertial( C_v̇_WB=jnp.zeros(6), C_v_WB=C_v_WB, W_H_C=W_H_C, W_v_WC=W_v_WC ) # =================================== # Initialize buffers and prepare data # =================================== # Get the parent array λ(i). 
# Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array # Compute 6D transforms of the base velocity. B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True) # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. # Ensure cached transforms stay on device when indexed with traced `i`. i_X_λi = jnp.asarray(data._joint_transforms) # Extract the joint motion subspaces. S = model.kin_dyn_parameters.motion_subspaces # Allocate the buffer to store the body-fixed link velocities. L_v_WL = jnp.zeros(shape=(model.number_of_links(), 6)) # Store the base velocity. with data.switch_velocity_representation(VelRepr.Body): B_v_WB = data.base_velocity L_v_WL = L_v_WL.at[0].set(B_v_WB) # Get the joint velocities. ṡ = data.joint_velocities # Allocate the buffer to store the body-fixed link accelerations, # and initialize the base acceleration. L_v̇_WL = jnp.zeros(shape=(model.number_of_links(), 6)) L_v̇_WL = L_v̇_WL.at[0].set(B_X_W @ W_v̇_WB) # ====================================== # Propagate accelerations and velocities # ====================================== # The computation of the bias forces is similar to the forward pass of RNEA, # this time with zero base and joint accelerations. Furthermore, here we do # not remove gravity during the propagation. # Initialize the loop. Carry = tuple[jtp.Matrix, jtp.Matrix] carry0: Carry = (L_v_WL, L_v̇_WL) def propagate_accelerations(carry: Carry, i: jtp.Int) -> tuple[Carry, None]: # Initialize index and unpack the carry. ii = i - 1 v, a = carry # Get the motion subspace of the joint. Si = S[i].squeeze() # Project the joint velocity into its motion subspace. vJ = Si * ṡ[ii] # Propagate the link body-fixed velocity. 
v_i = i_X_λi[i] @ v[λ[i]] + vJ v = v.at[i].set(v_i) # Propagate the link body-fixed acceleration considering zero joint acceleration. s̈ = 0.0 a_i = i_X_λi[i] @ a[λ[i]] + Si * s̈ + jaxsim.math.Cross.vx(v[i]) @ vJ a = a.at[i].set(a_i) return (v, a), None # Compute the body-fixed velocity and body-fixed apparent acceleration of the links. (L_v_WL, L_v̇_WL), _ = ( jax.lax.scan( f=propagate_accelerations, init=carry0, xs=jnp.arange(start=1, stop=model.number_of_links()), ) if model.number_of_links() > 1 else [(L_v_WL, L_v̇_WL), None] ) # =================================================================== # Convert the body-fixed 6D acceleration to the active representation # =================================================================== def body_to_other_representation( L_v̇_WL: jtp.Vector, L_v_WL: jtp.Vector, C_H_L: jtp.Matrix, L_v_CL: jtp.Vector ) -> jtp.Vector: """ Convert the body-fixed apparent acceleration L_v̇_WL to another representation C_v̇_WL expressed in a generic frame C. """ # In Mixed representation, we need to include a cross product in ℝ⁶. # In Inertial and Body representations, the cross product is always zero. C_X_L = jaxsim.math.Adjoint.from_transform(transform=C_H_L) return C_X_L @ (L_v̇_WL + jaxsim.math.Cross.vx(L_v_CL) @ L_v_WL) match data.velocity_representation: case VelRepr.Body: C_H_L = L_H_L = jnp.stack( # noqa: F841 [jnp.eye(4)] * model.number_of_links() ) L_v_CL = L_v_LL = jnp.zeros( # noqa: F841 shape=(model.number_of_links(), 6) ) case VelRepr.Inertial: C_H_L = W_H_L = data._link_transforms L_v_CL = L_v_WL case VelRepr.Mixed: W_H_L = data._link_transforms LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L) C_H_L = LW_H_L L_v_CL = L_v_LW_L = jax.vmap( # noqa: F841 lambda v: v.at[0:3].set(jnp.zeros(3)) )(L_v_WL) case _: raise ValueError(data.velocity_representation) # Convert from body-fixed to the active representation. 
O_v̇_WL = jax.vmap(body_to_other_representation)( L_v̇_WL=L_v̇_WL, L_v_WL=L_v_WL, C_H_L=C_H_L, L_v_CL=L_v_CL ) return O_v̇_WL @jax.jit def joint_transforms( model: JaxSimModel, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike ) -> jtp.Array: r""" Return the transforms of the joints. Args: model: The model to consider. joint_positions: The joint positions. base_transform: The homogeneous matrix defining the base pose. Returns: The stacked transforms :math:`{}^{i} \mathbf{H}_{\lambda(i)}(s)` of each joint. """ return model.kin_dyn_parameters.joint_transforms( joint_positions=joint_positions, base_transform=base_transform, ) # ====== # Energy # ====== @jax.jit @js.common.named_scope def mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float: """ Compute the mechanical energy of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The mechanical energy of the model. """ K = kinetic_energy(model=model, data=data) U = potential_energy(model=model, data=data) return (K + U).astype(float) @jax.jit @js.common.named_scope def kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float: """ Compute the kinetic energy of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The kinetic energy of the model. """ with data.switch_velocity_representation(velocity_representation=VelRepr.Body): B_ν = data.generalized_velocity M_B = free_floating_mass_matrix(model=model, data=data) K = 0.5 * B_ν.T @ M_B @ B_ν return K.squeeze().astype(float) @jax.jit @js.common.named_scope def potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float: """ Compute the potential energy of the model. Args: model: The model to consider. data: The data of the considered model. Returns: The potential energy of the model. 
""" m = total_mass(model=model) W_p̃_CoM = jnp.hstack([js.com.com_position(model=model, data=data), 1]) return jnp.sum((m * W_p̃_CoM)[2] * model.gravity) # =================== # Hw parametrization # =================== @jax.jit @js.common.named_scope def update_hw_parameters( model: JaxSimModel, scaling_factors: ScalingFactors ) -> JaxSimModel: """ Update the hardware parameters of the model by scaling the parameters of the links. This function applies scaling factors to the hardware metadata of the links, updating their shape, dimensions, density, and other related parameters. It recalculates the mass and inertia tensors of the links based on the updated metadata and adjusts the joint model transforms accordingly. Args: model: The JaxSimModel object to update. scaling_factors: A ScalingFactors object containing scaling factors for dimensions and density of the links. Returns: The updated JaxSimModel object with modified hardware parameters. """ kin_dyn_params: KinDynParameters = model.kin_dyn_parameters link_parameters: LinkParameters = kin_dyn_params.link_parameters hw_link_metadata: HwLinkMetadata = kin_dyn_params.hw_link_metadata has_joints = model.number_of_joints() > 0 def apply_scaling_single_link( link_shape, geometry, density, L_H_G, L_H_vis, L_H_pre, L_H_pre_mask, scaling_dims, scaling_density, ): """Apply scaling to a single link's numerical data.""" def scale_supported(_): shape_indices_map = jnp.array([[0, 1, 2], [0, 0, 1], [0, 0, 0], [0, 1, 2]]) per_link_indices = shape_indices_map[link_shape] scale_vector = scaling_dims[per_link_indices] # Update kinematics G_H_L = jaxsim.math.Transform.inverse(L_H_G) G_H_vis = G_H_L @ L_H_vis G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3]) L_H̅_G = L_H_G.at[:3, 3].set(scale_vector * L_H_G[:3, 3]) L_H̅_vis = L_H̅_G @ G_H̅_vis # Update shape parameters updated_geom = geometry * scaling_dims updated_dens = density * scaling_density return updated_geom, updated_dens, L_H̅_G, L_H̅_vis, scale_vector def 
scale_unsupported(_): return ( geometry, density, L_H_G, L_H_vis, jnp.ones_like(scaling_dims), ) return jax.lax.cond( link_shape == LinkParametrizableShape.Unsupported, scale_unsupported, scale_supported, operand=None, ) # Vmap over all links for basic scaling ( updated_geometry, updated_density, updated_L_H_G, updated_L_H_vis, scale_vectors, ) = jax.vmap(apply_scaling_single_link)( hw_link_metadata.link_shape, hw_link_metadata.geometry, hw_link_metadata.density, hw_link_metadata.L_H_G, hw_link_metadata.L_H_vis, hw_link_metadata.L_H_pre, hw_link_metadata.L_H_pre_mask, scaling_factors.dims, scaling_factors.density, ) # Handle joint transforms separately, only if model has joints def transform_all_joints(operands): """Transform all joint poses across all links.""" original_L_H_G, updated_L_H_G, scale_vectors, L_H_pre, L_H_pre_mask = operands # Vectorized transformation: (n_links, n_joints, 4, 4) # Express joint transforms in the original CoM frames. # Using the already-scaled L_H_G here introduces a second implicit # scaling term and distorts kinematic chain proportions. 
G_H_L_all = jax.vmap(jaxsim.math.Transform.inverse)( original_L_H_G ) # (n_links, 4, 4) # Use batch matrix multiply with broadcasting # G_H_L_all: (n_links, 4, 4) -> (n_links, 1, 4, 4) # L_H_pre: (n_links, n_joints, 4, 4) # Result: (n_links, n_joints, 4, 4) G_H_pre = G_H_L_all[:, None, :, :] @ L_H_pre # Scale translation components G_H̅_pre = G_H_pre.at[:, :, :3, 3].set( jnp.where( L_H_pre_mask[:, :, None], scale_vectors[:, None, :] * G_H_pre[:, :, :3, 3], G_H_pre[:, :, :3, 3], ) ) # Transform back to link frames # updated_L_H_G: (n_links, 4, 4) -> (n_links, 1, 4, 4) # G_H̅_pre: (n_links, n_joints, 4, 4) # Result: (n_links, n_joints, 4, 4) return updated_L_H_G[:, None, :, :] @ G_H̅_pre updated_L_H_pre = jax.lax.cond( has_joints, transform_all_joints, lambda operands: operands[3], # Return L_H_pre unchanged operand=( hw_link_metadata.L_H_G, updated_L_H_G, scale_vectors, hw_link_metadata.L_H_pre, hw_link_metadata.L_H_pre_mask, ), ) # Create updated HwLinkMetadata updated_hw_link_metadata = hw_link_metadata.replace( geometry=updated_geometry, density=updated_density, L_H_G=updated_L_H_G, L_H_vis=updated_L_H_vis, L_H_pre=updated_L_H_pre, ) # Compute mass and inertia once and unpack the results m_updated, I_com_updated = HwLinkMetadata.compute_mass_and_inertia( updated_hw_link_metadata ) # Rotate the inertia tensor at CoM with the link orientation, and store # it in KynDynParameters. 
I_L_updated = jax.vmap( lambda metadata, I_com: metadata.L_H_G[:3, :3] @ I_com @ metadata.L_H_G[:3, :3].T )(updated_hw_link_metadata, I_com_updated) # Update link parameters updated_link_parameters = link_parameters.replace( mass=m_updated, inertia_elements=jax.vmap(LinkParameters.flatten_inertia_tensor)(I_L_updated), center_of_mass=jax.vmap(lambda metadata: metadata.L_H_G[:3, 3])( updated_hw_link_metadata ), ) if kin_dyn_params.contact_parameters.body: # Compute the contact parameters points = HwLinkMetadata.compute_contact_points( original_contact_params=kin_dyn_params.contact_parameters, link_shapes=updated_hw_link_metadata.link_shape, original_com_positions=link_parameters.center_of_mass, updated_com_positions=updated_link_parameters.center_of_mass, scaling_factors=scaling_factors, ) # Update contact parameters updated_contact_parameters = kin_dyn_params.contact_parameters.replace( point=points ) else: updated_contact_parameters = kin_dyn_params.contact_parameters # Update joint model transforms (λ_H_pre) def update_λ_H_pre(joint_index): # Extract the transforms and masks for the current joint index across all links L_H_pre_for_joint = updated_hw_link_metadata.L_H_pre[:, joint_index] L_H_pre_mask_for_joint = updated_hw_link_metadata.L_H_pre_mask[:, joint_index] # Select the first valid transform (if any) using the mask first_valid_index = jnp.argmax(L_H_pre_mask_for_joint) selected_transform = L_H_pre_for_joint[first_valid_index] # Check if any valid transform exists has_valid_transform = L_H_pre_mask_for_joint.any() # Fallback to the original λ_H_pre if no valid transform exists fallback_transform = kin_dyn_params.joint_model.λ_H_pre[joint_index + 1] # Return the selected transform or fallback return jnp.where(has_valid_transform, selected_transform, fallback_transform) if has_joints: # Apply the update function to all joint indices updated_λ_H_pre = jax.vmap(update_λ_H_pre)( jnp.arange(kin_dyn_params.number_of_joints()) ) # NOTE: λ_H_pre should be of len 
(1+n_joints) with the 0-th element equal # to identity to represent the world-to-base tree transform. See JointModel class updated_λ_H_pre_with_base = jnp.concatenate( (jnp.eye(4).reshape(1, 4, 4), updated_λ_H_pre), axis=0 ) # Replace the joint model with the updated transforms updated_joint_model = kin_dyn_params.joint_model.replace( λ_H_pre=updated_λ_H_pre_with_base ) else: # If there are no joints, we can just use the identity transform updated_joint_model = kin_dyn_params.joint_model # Replace the kin_dyn_parameters with updated values updated_kin_dyn_params = kin_dyn_params.replace( link_parameters=updated_link_parameters, contact_parameters=updated_contact_parameters, hw_link_metadata=updated_hw_link_metadata, joint_model=updated_joint_model, ) # Return the updated model return model.replace(kin_dyn_parameters=updated_kin_dyn_params) # ========== # Simulation # ========== @jax.jit @js.common.named_scope def step( model: JaxSimModel, data: js.data.JaxSimModelData, *, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, ) -> js.data.JaxSimModelData: """ Perform a simulation step. Args: model: The model to consider. data: The data of the considered model. dt: The time step to consider. If not specified, it is read from the model. link_forces: The 6D forces to apply to the links expressed in same representation of data. joint_force_references: The joint force references to consider. Returns: The new data of the model after the simulation step. Note: In order to reduce the occurrences of frame conversions performed internally, it is recommended to use inertial-fixed velocity representation. This can be particularly useful for automatically differentiated logic. 
""" # TODO: some contact models here may want to perform a dynamic filtering of # the enabled collidable points # Extract the inputs O_f_L_external = jnp.atleast_2d( jnp.array(link_forces, dtype=float).squeeze() if link_forces is not None else jnp.zeros((model.number_of_links(), 6)) ) # Get the external forces in inertial-fixed representation. W_f_L_external = js.data.JaxSimModelData.other_representation_to_inertial( O_f_L_external, other_representation=data.velocity_representation, transform=data._link_transforms, is_force=True, ) τ_references = jnp.atleast_1d( jnp.array(joint_force_references, dtype=float).squeeze() if joint_force_references is not None else jnp.zeros(model.dofs()) ) # ================================ # Compute the total joint torques # ================================ τ_total = js.actuation_model.compute_resultant_torques( model, data, joint_force_references=τ_references ) # ============================= # Advance the simulation state # ============================= from .integrators import _INTEGRATORS_MAP integrator_fn = _INTEGRATORS_MAP[model.integrator] data_tf = integrator_fn( model=model, data=data, link_forces=W_f_L_external, joint_torques=τ_total, ) data_tf = model.contact_model.update_velocity_after_impact( model=model, data=data_tf ) return data_tf ================================================ FILE: src/jaxsim/api/ode.py ================================================ import jax import jax.numpy as jnp import jaxsim.api as js import jaxsim.typing as jtp from jaxsim.math import Quaternion, Skew from jaxsim.rbda.kinematic_constraints import compute_constraint_wrenches from .common import VelRepr # ================================== # Functions defining system dynamics # ================================== def system_acceleration( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_forces: jtp.MatrixLike | None = None, joint_torques: jtp.VectorLike | None = None, ) -> tuple[jtp.Vector, jtp.Vector, dict[str, 
jtp.PyTree]]: """ Compute the system acceleration in the active representation. Args: model: The model to consider. data: The data of the considered model. link_forces: The 6D forces to apply to the links expressed in the same velocity representation of data. joint_torques: The joint torques applied to the joints. Returns: A tuple containing the base 6D acceleration in the active representation, the joint accelerations, and the contact state. """ # ==================== # Validate input data # ==================== # Build link forces if not provided. f_L = ( jnp.atleast_2d(link_forces.squeeze()) if link_forces is not None else jnp.zeros((model.number_of_links(), 6)) ).astype(float) # ====================== # Compute contact forces # ====================== W_f_L_terrain = jnp.zeros_like(f_L) contact_state = data.contact_state if len(model.kin_dyn_parameters.contact_parameters.body) > 0: # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact # with the terrain. W_f_L_terrain, contact_state_derivative = js.contact.link_contact_forces( model=model, data=data, link_forces=f_L, joint_torques=joint_torques, ) # Update the contact state data. This is necessary only for the contact models # that require propagation and integration of contact state. contact_state = model.contact_model.update_contact_state( contact_state_derivative ) # ================================== # Compute kinematic constraint forces # ================================== # Sum up all the forces: external + contact W_f_L_total = f_L + W_f_L_terrain # Compute the 6D forces W_f ∈ ℝ^{n_constraints × 2 × 6} applied to links due to # kinematic constraints. 
W_f_L_constraints = compute_constraint_wrenches( model=model, data=data, link_forces_inertial=W_f_L_total, joint_force_references=joint_torques, ) # Apply constraint forces to the corresponding links if W_f_L_constraints.shape[0] > 0: # Get the constraint map from the model's kinematic parameters constraint_map = model.kin_dyn_parameters.constraints if constraint_map is not None: # Stack the parent link indices for both sides of each constraint parent_indices_flat = jnp.concatenate( [constraint_map.parent_link_idxs_1, constraint_map.parent_link_idxs_2], ) # Flatten the constraint wrenches to match the flattened parent indices constraint_wrenches_flat = W_f_L_constraints.reshape(-1, 6) # Apply constraint wrenches using scatter_add for better performance W_f_L_total = W_f_L_total.at[parent_indices_flat].add( constraint_wrenches_flat ) # Store the link forces in a references object. references = js.references.JaxSimModelReferences.build( model=model, data=data, velocity_representation=data.velocity_representation, link_forces=W_f_L_total, ) # Compute forward dynamics. # # - Joint accelerations: s̈ ∈ ℝⁿ # - Base acceleration: v̇_WB ∈ ℝ⁶ # # Note that ABA returns the base acceleration in the velocity representation # stored in the `data` object. v̇_WB, s̈ = js.model.forward_dynamics_aba( model=model, data=data, joint_forces=joint_torques, link_forces=references.link_forces(model=model, data=data), ) return v̇_WB, s̈, contact_state @jax.jit @js.common.named_scope def system_position_dynamics( data: js.data.JaxSimModelData, baumgarte_quaternion_regularization: jtp.FloatLike = 1.0, ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]: r""" Compute the dynamics of the system position. Args: data: The data of the considered model. baumgarte_quaternion_regularization: The Baumgarte regularization coefficient for adjusting the quaternion norm. 
Returns: A tuple containing the derivative of the base position, the derivative of the base quaternion, and the derivative of the joint positions. Note: In inertial-fixed representation, the linear component of the base velocity is not the derivative of the base position. In fact, the base velocity is defined as: :math:`{} ^W v_{W, B} = \begin{bmatrix} {} ^W \dot{p}_B S({} ^W \omega_{W, B}) {} ^W p _B\\ {} ^W \omega_{W, B} \end{bmatrix}`. Where :math:`S(\cdot)` is the skew-symmetric matrix operator. """ ṡ = data.joint_velocities W_Q_B = data.base_orientation W_ω_WB = data.base_velocity[3:6] W_ṗ_B = data.base_velocity[0:3] + Skew.wedge(W_ω_WB) @ data.base_position W_Q̇_B = Quaternion.derivative( quaternion=W_Q_B, omega=W_ω_WB, omega_in_body_fixed=False, K=baumgarte_quaternion_regularization, ).squeeze() return W_ṗ_B, W_Q̇_B, ṡ @jax.jit @js.common.named_scope def system_dynamics( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_forces: jtp.Vector | None = None, joint_torques: jtp.Vector | None = None, baumgarte_quaternion_regularization: jtp.FloatLike = 1.0, ) -> dict[str, jtp.Vector]: """ Compute the dynamics of the system. Args: model: The model to consider. data: The data of the considered model. link_forces: The 6D forces to apply to the links expressed in the frame corresponding to the velocity representation of `data`. joint_torques: The joint torques acting on the joints. baumgarte_quaternion_regularization: The Baumgarte regularization coefficient used to adjust the norm of the quaternion (only used in integrators not operating on the SO(3) manifold). Returns: A dictionary containing the derivatives of the base position, the base quaternion, the joint positions, the base linear velocity, the base angular velocity, and the joint velocities. 
""" with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial): W_v̇_WB, s̈, contact_state_derivative = system_acceleration( model=model, data=data, joint_torques=joint_torques, link_forces=link_forces, ) W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics( data=data, baumgarte_quaternion_regularization=baumgarte_quaternion_regularization, ) return dict( base_position=W_ṗ_B, base_quaternion=W_Q̇_B, joint_positions=ṡ, base_linear_velocity=W_v̇_WB[0:3], base_angular_velocity=W_v̇_WB[3:6], joint_velocities=s̈, contact_state=contact_state_derivative, ) ================================================ FILE: src/jaxsim/api/references.py ================================================ from __future__ import annotations import functools import jax import jax.numpy as jnp import jax_dataclasses import jaxsim.api as js import jaxsim.typing as jtp from jaxsim import exceptions from jaxsim.utils.tracing import not_tracing from .common import VelRepr try: from typing import Self except ImportError: from typing_extensions import Self @jax_dataclasses.pytree_dataclass class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation): """ Class containing the references for a `JaxSimModel` object. Attributes: _link_forces: The link 6D forces in inertial-fixed representation. _joint_force_references: The joint force references. """ _link_forces: jtp.Matrix _joint_force_references: jtp.Vector @staticmethod def zero( model: js.model.JaxSimModel, data: js.data.JaxSimModelData | None = None, velocity_representation: VelRepr = VelRepr.Inertial, ) -> JaxSimModelReferences: """ Create a `JaxSimModelReferences` object with zero references. Args: model: The model for which to create the zero references. data: The data of the model, only needed if the velocity representation is not inertial-fixed. velocity_representation: The velocity representation to use. Returns: A `JaxSimModelReferences` object with zero state. 
""" return JaxSimModelReferences.build( model=model, data=data, velocity_representation=velocity_representation ) @staticmethod def build( model: js.model.JaxSimModel, joint_force_references: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, data: js.data.JaxSimModelData | None = None, velocity_representation: VelRepr | None = None, ) -> JaxSimModelReferences: """ Create a `JaxSimModelReferences` object with the given references. Args: model: The model for which to create the state. joint_force_references: The joint force references. link_forces: The link 6D forces in the desired representation. data: The data of the model, only needed if the velocity representation is not inertial-fixed. velocity_representation: The velocity representation to use. Returns: A `JaxSimModelReferences` object with the given references. """ # Create or adjust joint force references. joint_force_references = jnp.atleast_1d( jnp.array(joint_force_references, dtype=float).squeeze() if joint_force_references is not None else jnp.zeros(model.dofs()) ).astype(float) # Create or adjust link forces. f_L = jnp.atleast_2d( jnp.array(link_forces, dtype=float).squeeze() if link_forces is not None else jnp.zeros((model.number_of_links(), 6)) ).astype(float) # Select the velocity representation. velocity_representation = ( velocity_representation if velocity_representation is not None else getattr(data, "velocity_representation", VelRepr.Inertial) ) # Create a zero references object. references = JaxSimModelReferences( _link_forces=f_L, _joint_force_references=joint_force_references, velocity_representation=velocity_representation, ) # If the velocity representation is inertial-fixed, we can return # the references directly, as we store the link forces in this frame. if velocity_representation is VelRepr.Inertial: return references # Store the joint force references. 
references = references.set_joint_force_references( forces=joint_force_references, model=model, joint_names=model.joint_names(), ) # Apply the link forces. references = references.apply_link_forces( forces=f_L, model=model, data=data, link_names=model.link_names(), additive=False, ) return references def valid(self, model: js.model.JaxSimModel | None = None) -> bool: """ Check if the current references are valid for the given model. Args: model: The model to check against. Returns: `True` if the current references are valid for the given model, `False` otherwise. """ if model is None: return True shape = self._joint_force_references.shape expected_shape = (model.dofs(),) if shape != expected_shape: return False shape = self._link_forces.shape expected_shape = (model.number_of_links(), 6) if shape != expected_shape: return False return True # ================== # Extract quantities # ================== @js.common.named_scope @functools.partial(jax.jit, static_argnames=["link_names"]) def link_forces( self, model: js.model.JaxSimModel | None = None, data: js.data.JaxSimModelData | None = None, link_names: tuple[str, ...] | None = None, ) -> jtp.Matrix: """ Return the link forces expressed in the frame of the active representation. Args: model: The model to consider. data: The data of the considered model. link_names: The names of the links corresponding to the forces. Returns: If no model and no link names are provided, the link forces as a `(n_links,6)` matrix corresponding to the default link serialization of the original model used to build the actuation object. If a model is provided and no link names are provided, the link forces as a `(n_links,6)` matrix corresponding to the serialization of the provided model. If both a model and link names are provided, the link forces as a `(len(link_names),6)` matrix corresponding to the serialization of the passed link names vector. 
Note: The returned link forces are those passed as user inputs when integrating the dynamics of the model. They are summed with other forces related e.g. to the contact model and other kinematic constraints. """ W_f_L = self._link_forces # Return all link forces in inertial-fixed representation using the implicit # serialization. if model is None: if self.velocity_representation is not VelRepr.Inertial: msg = "Missing model to use a representation different from {}" raise ValueError(msg.format(VelRepr.Inertial.name)) if link_names is not None: raise ValueError("Link names cannot be provided without a model") return W_f_L # If we have the model, we can extract the link names, if not provided. link_idxs = ( js.link.names_to_idxs(link_names=link_names, model=model) if link_names is not None else jnp.arange(model.number_of_links()) ) # In inertial-fixed representation, we already have the link forces. if self.velocity_representation is VelRepr.Inertial: return W_f_L[link_idxs, :] if data is None: msg = "Missing model data to use a representation different from {}" raise ValueError(msg.format(VelRepr.Inertial.name)) if not_tracing(self._link_forces) and not data.valid(model=model): raise ValueError("The provided data is not valid for the model") # Helper function to convert a single 6D force to the active representation # considering as body the link (i.e. L_f_L and LW_f_L). def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix: return jax.vmap( lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation( array=W_f_L, other_representation=self.velocity_representation, transform=W_H_L, is_force=True, ) )(W_f_L, W_H_L) # The f_L output is either L_f_L or LW_f_L, depending on the representation. W_H_L = data._link_transforms f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]) return f_L def joint_force_references( self, model: js.model.JaxSimModel | None = None, joint_names: tuple[str, ...] 
| None = None, ) -> jtp.Vector: """ Return the joint force references. Args: model: The model to consider. joint_names: The names of the joints corresponding to the forces. Returns: If no model and no joint names are provided, the joint forces as a `(DoFs,)` vector corresponding to the default joint serialization of the original model used to build the actuation object. If a model is provided and no joint names are provided, the joint forces as a `(DoFs,)` vector corresponding to the serialization of the provided model. If both a model and joint names are provided, the joint forces as a `(len(joint_names),)` vector corresponding to the serialization of the passed joint names vector. Note: The returned joint forces are those passed as user inputs when integrating the dynamics of the model. They are summed with other joint forces related e.g. to the enforcement of other kinematic constraints. Keep also in mind that the presence of joint friction and other similar effects can make the actual joint forces different from the references. """ if model is None: if joint_names is not None: raise ValueError("Joint names cannot be provided without a model") return self._joint_force_references if not_tracing(self._joint_force_references) and not self.valid(model=model): msg = "The actuation object is not compatible with the provided model" raise ValueError(msg) joint_idxs = ( js.joint.names_to_idxs(joint_names=joint_names, model=model) if joint_names is not None else jnp.arange(model.number_of_joints()) ) return jnp.atleast_1d( self._joint_force_references[joint_idxs].squeeze() ).astype(float) # ================ # Store quantities # ================ @js.common.named_scope @functools.partial(jax.jit, static_argnames=["joint_names"]) def set_joint_force_references( self, forces: jtp.VectorLike, model: js.model.JaxSimModel | None = None, joint_names: tuple[str, ...] | None = None, ) -> Self: """ Set the joint force references. Args: forces: The joint force references. 
model: The model to consider, only needed if a joint serialization different from the implicit one is used. joint_names: The names of the joints corresponding to the forces. Returns: A new `JaxSimModelReferences` object with the given joint force references. """ forces = jnp.atleast_1d(jnp.array(forces, dtype=float).squeeze()) def replace(forces: jtp.Vector) -> JaxSimModelReferences: return self.replace( validate=True, _joint_force_references=jnp.atleast_1d(forces.squeeze()).astype(float), ) if model is None: return replace(forces=forces) if not_tracing(forces) and not self.valid(model=model): msg = "The references object is not compatible with the provided model" raise ValueError(msg) joint_idxs = ( js.joint.names_to_idxs(joint_names=joint_names, model=model) if joint_names is not None else jnp.arange(model.number_of_joints()) ) return replace(forces=self._joint_force_references.at[joint_idxs].set(forces)) @js.common.named_scope @functools.partial(jax.jit, static_argnames=["link_names", "additive"]) def apply_link_forces( self, forces: jtp.MatrixLike, model: js.model.JaxSimModel | None = None, data: js.data.JaxSimModelData | None = None, link_names: tuple[str, ...] | str | None = None, additive: bool = False, ) -> Self: """ Apply the link forces. Args: forces: The link 6D forces in the active representation. model: The model to consider, only needed if a link serialization different from the implicit one is used. data: The data of the considered model, only needed if the velocity representation is not inertial-fixed. link_names: The names of the links corresponding to the forces. additive: Whether to add the forces to the existing ones instead of replacing them. Returns: A new `JaxSimModelReferences` object with the given link forces. Note: The link forces must be expressed in the active representation. Then, we always convert and store forces in inertial-fixed representation. 
""" f_L = jnp.atleast_2d(forces).astype(float) # Helper function to replace the link forces. def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: return self.replace( validate=True, _link_forces=jnp.atleast_2d(forces.squeeze()).astype(float), ) # In this case, we allow only to set the inertial 6D forces to all links # using the implicit link serialization. if model is None: if self.velocity_representation is not VelRepr.Inertial: msg = "Missing model to use a representation different from {}" raise ValueError(msg.format(VelRepr.Inertial.name)) if link_names is not None: raise ValueError("Link names cannot be provided without a model") W_f_L = f_L W_f0_L = jnp.zeros_like(W_f_L) if not additive else self._link_forces return replace(forces=W_f0_L + W_f_L) if link_names is not None and len(link_names) != f_L.shape[0]: msg = "The number of link names ({}) must match the number of forces ({})" raise ValueError(msg.format(len(link_names), f_L.shape[0])) # Extract the link indices. link_idxs = ( js.link.names_to_idxs(link_names=link_names, model=model) if link_names is not None else jnp.arange(model.number_of_links()) ) # Compute the bias depending on whether we either set or add the link forces. W_f0_L = ( jnp.zeros_like(f_L) if not additive else self._link_forces[link_idxs, :] ) # If inertial-fixed representation, we can directly store the link forces. if self.velocity_representation is VelRepr.Inertial: W_f_L = f_L return replace( forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L) ) if data is None: msg = "Missing model data to use a representation different from {}" raise ValueError(msg.format(VelRepr.Inertial.name)) if not_tracing(forces) and not data.valid(model=model): raise ValueError("The provided data is not valid for the model") W_H_L = data._link_transforms # Convert a single 6D force to the inertial representation # considering as body the link (i.e. L_f_L and LW_f_L). 
# The f_L input is either L_f_L or LW_f_L, depending on the representation. W_f_L = JaxSimModelReferences.other_representation_to_inertial( array=f_L, other_representation=self.velocity_representation, transform=W_H_L[link_idxs] if model.number_of_links() > 1 else W_H_L, is_force=True, ) return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L)) def apply_frame_forces( self, forces: jtp.MatrixLike, model: js.model.JaxSimModel, data: js.data.JaxSimModelData, frame_names: tuple[str, ...] | str | None = None, additive: bool = False, ) -> Self: """ Apply the frame forces. Args: forces: The frame 6D forces in the active representation. model: The model to consider, only needed if a frame serialization different from the implicit one is used. data: The data of the considered model, only needed if the velocity representation is not inertial-fixed. frame_names: The names of the frames corresponding to the forces. additive: Whether to add the forces to the existing ones instead of replacing them. Returns: A new `JaxSimModelReferences` object with the given frame forces. Note: The frame forces must be expressed in the active representation. Then, we always convert and store forces in inertial-fixed representation. """ f_F = jnp.atleast_2d(forces).astype(float) if len(frame_names) != f_F.shape[0]: msg = "The number of frame names ({}) must match the number of forces ({})" raise ValueError(msg.format(len(frame_names), f_F.shape[0])) # Extract the frame indices. 
frame_idxs = ( js.frame.names_to_idxs(frame_names=frame_names, model=model) if frame_names is not None else jnp.arange(len(model.frame_names())) ) parent_link_idxs = jnp.array(model.kin_dyn_parameters.frame_parameters.body)[ frame_idxs - model.number_of_links() ] exceptions.raise_value_error_if( condition=~data.valid(model=model), msg="The provided data is not valid for the model", ) W_H_Fi = jax.vmap( lambda frame_idx: js.frame.transform( model=model, data=data, frame_index=frame_idx ) )(frame_idxs) # Helper function to convert a single 6D force to the inertial representation # considering as body the frame (i.e. L_f_F and LW_f_F). def to_inertial(f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike) -> jtp.Matrix: return JaxSimModelReferences.other_representation_to_inertial( array=f_F, other_representation=self.velocity_representation, transform=W_H_F, is_force=True, ) match self.velocity_representation: case VelRepr.Inertial: W_f_F = f_F case VelRepr.Body | VelRepr.Mixed: W_f_F = jax.vmap(to_inertial)(f_F, W_H_Fi) case _: raise ValueError("Invalid velocity representation.") # Sum the forces on the parent links. mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links()) W_f_L = mask.T @ W_f_F with self.switch_velocity_representation( velocity_representation=VelRepr.Inertial ): references = self.apply_link_forces( model=model, data=data, forces=W_f_L, additive=additive, ) with references.switch_velocity_representation( velocity_representation=self.velocity_representation ): return references ================================================ FILE: src/jaxsim/exceptions.py ================================================ import os import jax def raise_if( condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs ) -> None: """ Raise a host-side exception if a condition is met. Useful in jit-compiled functions. Args: condition: The boolean condition of the evaluated expression that triggers the exception during runtime. 
exception: The type of exception to raise. msg: The message to display when the exception is raised. The message can be a format string (fmt), whose fields are filled with the args and kwargs. *args: The arguments to fill the format string. **kwargs: The keyword arguments to fill the format string """ # Disable host callback if running on unsupported hardware or if the user # explicitly disabled it. if jax.devices()[0].platform in {"tpu", "METAL"} or not os.environ.get( "JAXSIM_ENABLE_EXCEPTIONS", 0 ): return # Check early that the format string is well-formed. try: _ = msg.format(*args, **kwargs) except Exception as e: msg = "Error in formatting exception message with args={} and kwargs={}" raise ValueError(msg.format(args, kwargs)) from e def _raise_exception(condition: bool, *args, **kwargs) -> None: """The function called by the JAX callback.""" if condition: raise exception(msg.format(*args, **kwargs)) def _callback(args, kwargs) -> None: """The function that calls the JAX callback, executed only when needed.""" jax.debug.callback(_raise_exception, condition, *args, **kwargs) # Since running a callable on the host is expensive, we prevent its execution # if the condition is False with a low-level conditional expression. def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None: return jax.lax.cond( condition, _callback, lambda args, kwargs: None, args, kwargs, ) return _run_callback_only_if_condition_is_true(*args, **kwargs) def raise_runtime_error_if( condition: bool | jax.Array, msg: str, *args, **kwargs ) -> None: """ Raise a RuntimeError if a condition is met. Useful in jit-compiled functions. """ return raise_if(condition, RuntimeError, msg, *args, **kwargs) def raise_value_error_if( condition: bool | jax.Array, msg: str, *args, **kwargs ) -> None: """ Raise a ValueError if a condition is met. Useful in jit-compiled functions. 
""" return raise_if(condition, ValueError, msg, *args, **kwargs) ================================================ FILE: src/jaxsim/logging.py ================================================ import enum import inspect import logging import os import warnings import coloredlogs class JaxSimWarning(UserWarning): pass _original_showwarning = warnings.showwarning _jaxsim_root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) def pretty_jaxsim_warning(message, category, filename, lineno, file=None, line=None): try: caller_frame = inspect.stack()[2] caller_file = caller_frame.filename except Exception: caller_file = filename if caller_file.startswith(_jaxsim_root_dir): print(f"\033[93m⚠️ {category.__name__}:\033[0m {message}") print(f" → {filename}:{lineno}") else: _original_showwarning(message, category, filename, lineno, file, line) # Register filter & formatter only for JaxSimWarning # and configure it to show each warning only once warnings.showwarning = pretty_jaxsim_warning warnings.filterwarnings("once") # Utility function to issue a JaxSim warning def jaxsim_warn(msg): warnings.warn(msg, category=JaxSimWarning, stacklevel=2) LOGGER_NAME = "jaxsim" class LoggingLevel(enum.IntEnum): NOTSET = logging.NOTSET DEBUG = logging.DEBUG INFO = logging.INFO WARNING = logging.WARNING ERROR = logging.ERROR CRITICAL = logging.CRITICAL def _logger() -> logging.Logger: return logging.getLogger(name=LOGGER_NAME) def set_logging_level(level: int | LoggingLevel = LoggingLevel.WARNING): if isinstance(level, int): level = LoggingLevel(level) _logger().setLevel(level=level.value) def get_logging_level() -> LoggingLevel: level = _logger().getEffectiveLevel() return LoggingLevel(level) def configure(level: LoggingLevel = LoggingLevel.WARNING) -> None: info("Configuring the 'jaxsim' logger") handler = logging.StreamHandler() fmt = "%(name)s[%(process)d] %(levelname)s %(message)s" handler.setFormatter(fmt=coloredlogs.ColoredFormatter(fmt=fmt)) 
_logger().addHandler(hdlr=handler) # Do not propagate the messages to handlers of parent loggers # (preventing duplicate logging) _logger().propagate = False set_logging_level(level=level) def debug(msg: str = "") -> None: _logger().debug(msg=msg) def info(msg: str = "") -> None: _logger().info(msg=msg) def warning(msg: str = "") -> None: _logger().warning(msg=msg) def error(msg: str = "") -> None: _logger().error(msg=msg) def critical(msg: str = "") -> None: _logger().critical(msg=msg) def exception(msg: str = "") -> None: _logger().exception(msg=msg) ================================================ FILE: src/jaxsim/math/__init__.py ================================================ from .adjoint import Adjoint from .cross import Cross from .inertia import Inertia from .quaternion import Quaternion from .rotation import Rotation from .skew import Skew from .transform import Transform from .utils import safe_norm from .joint_model import JointModel, supported_joint_motion # isort:skip # Define the default standard gravity constant. STANDARD_GRAVITY = 9.81 ================================================ FILE: src/jaxsim/math/adjoint.py ================================================ import jax.numpy as jnp import jaxlie import jaxsim.typing as jtp from .skew import Skew class Adjoint: """ A utility class for adjoint matrix operations. """ @staticmethod def from_quaternion_and_translation( quaternion: jtp.Vector | None = None, translation: jtp.Vector | None = None, inverse: bool = False, normalize_quaternion: bool = False, ) -> jtp.Matrix: """ Create an adjoint matrix from a quaternion and a translation. Args: quaternion: A quaternion vector (4D) representing orientation. translation: A translation vector (3D). inverse: Whether to compute the inverse adjoint. normalize_quaternion: Whether to normalize the quaternion before creating the adjoint. Returns: The adjoint matrix. 
""" quaternion = quaternion if quaternion is not None else jnp.array([1.0, 0, 0, 0]) translation = translation if translation is not None else jnp.zeros(3) assert quaternion.size == 4 assert translation.size == 3 Q_sixd = jaxlie.SO3(wxyz=quaternion) Q_sixd = Q_sixd if not normalize_quaternion else Q_sixd.normalize() return Adjoint.from_rotation_and_translation( rotation=Q_sixd.as_matrix(), translation=translation, inverse=inverse ) @staticmethod def from_transform(transform: jtp.MatrixLike, inverse: bool = False) -> jtp.Matrix: """ Create an adjoint matrix from a transformation matrix. Args: transform: A 4x4 transformation matrix. inverse: Whether to compute the inverse adjoint. Returns: The 6x6 adjoint matrix. """ A_H_B = transform return ( jaxlie.SE3.from_matrix(matrix=A_H_B).adjoint() if not inverse else jaxlie.SE3.from_matrix(matrix=A_H_B).inverse().adjoint() ) @staticmethod def from_rotation_and_translation( rotation: jtp.Matrix | None = None, translation: jtp.Vector | None = None, inverse: bool = False, ) -> jtp.Matrix: """ Create an adjoint matrix from a rotation matrix and a translation vector. Args: rotation: A 3x3 rotation matrix. translation: A translation vector (3D). inverse: Whether to compute the inverse adjoint. Default is False. Returns: The adjoint matrix. """ rotation = rotation if rotation is not None else jnp.eye(3) translation = translation if translation is not None else jnp.zeros(3) assert rotation.shape == (3, 3) assert translation.size == 3 A_R_B = rotation.squeeze() A_o_B = translation.squeeze() if not inverse: X = A_X_B = jnp.vstack( # noqa: F841 [ jnp.block([A_R_B, Skew.wedge(A_o_B) @ A_R_B]), jnp.block([jnp.zeros(shape=(3, 3)), A_R_B]), ] ) else: X = B_X_A = jnp.vstack( # noqa: F841 [ jnp.block([A_R_B.T, -A_R_B.T @ Skew.wedge(A_o_B)]), jnp.block([jnp.zeros(shape=(3, 3)), A_R_B.T]), ] ) return X @staticmethod def to_transform(adjoint: jtp.Matrix) -> jtp.Matrix: """ Convert an adjoint matrix to a transformation matrix. 
Args: adjoint: The adjoint matrix (6x6). Returns: The transformation matrix (4x4). """ X = adjoint.squeeze() assert X.shape == (6, 6) R = X[0:3, 0:3] o_x_R = X[0:3, 3:6] H = jnp.vstack( [ jnp.block([R, Skew.vee(matrix=o_x_R @ R.T)]), jnp.array([0, 0, 0, 1]), ] ) return H @staticmethod def inverse(adjoint: jtp.Matrix) -> jtp.Matrix: """ Compute the inverse of an adjoint matrix. Args: adjoint: The adjoint matrix. Returns: The inverse adjoint matrix. """ A_X_B = adjoint.reshape(-1, 6, 6) A_R_B_T = jnp.swapaxes(A_X_B[..., 0:3, 0:3], -2, -1) A_T_B = A_X_B[..., 0:3, 3:6] return jnp.concatenate( [ jnp.concatenate( [A_R_B_T, -A_R_B_T @ A_T_B @ A_R_B_T], axis=-1, ), jnp.concatenate([jnp.zeros_like(A_R_B_T), A_R_B_T], axis=-1), ], axis=-2, ).reshape(adjoint.shape) ================================================ FILE: src/jaxsim/math/cross.py ================================================ import jax.numpy as jnp import jaxsim.typing as jtp from .skew import Skew class Cross: """ A utility class for cross product matrix operations. """ @staticmethod def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix: """ Compute the cross product matrix for 6D velocities. Args: velocity_sixd: A 6D velocity vector [v, ω]. Returns: The cross product matrix (6x6). Raises: ValueError: If the input vector does not have a size of 6. """ velocity_sixd = velocity_sixd.reshape(-1, 6) v, ω = jnp.split(velocity_sixd, 2, axis=-1) v_cross = jnp.concatenate( [ jnp.concatenate( [Skew.wedge(ω), jnp.zeros((ω.shape[0], 3, 3)).squeeze()], axis=-2 ), jnp.concatenate([Skew.wedge(v), Skew.wedge(ω)], axis=-2), ], axis=-1, ) return v_cross @staticmethod def vx_star(velocity_sixd: jtp.Vector) -> jtp.Matrix: """ Compute the negative transpose of the cross product matrix for 6D velocities. Args: velocity_sixd: A 6D velocity vector [v, ω]. Returns: The negative transpose of the cross product matrix (6x6). Raises: ValueError: If the input vector does not have a size of 6. 
""" v_cross_star = -Cross.vx(velocity_sixd).T return v_cross_star ================================================ FILE: src/jaxsim/math/inertia.py ================================================ import jax.numpy as jnp import jaxsim.typing as jtp from .skew import Skew class Inertia: """ A utility class for inertia matrix operations. """ @staticmethod def to_sixd(mass: jtp.Float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Matrix: """ Convert mass, center of mass, and inertia matrix to a 6x6 inertia matrix. Args: mass: The mass of the body. com: The center of mass position (3D). I: The 3x3 inertia matrix. Returns: The 6x6 inertia matrix. Raises: ValueError: If the shape of the inertia matrix I is not (3, 3). """ if I.shape != (3, 3): raise ValueError(I, I.shape) c = Skew.wedge(vector=com) M = jnp.vstack( [ jnp.block([mass * jnp.eye(3), mass * c.T]), jnp.block([mass * c, I + mass * c @ c.T]), ] ) return M @staticmethod def to_params(M: jtp.Matrix) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]: """ Convert a 6x6 inertia matrix to mass, center of mass, and inertia matrix. Args: M: The 6x6 inertia matrix. Returns: A tuple containing mass, center of mass (3D), and inertia matrix (3x3). Raises: ValueError: If the input matrix M has an unexpected shape. 
""" m = jnp.diag(M[0:3, 0:3]).sum() / 3 mC = M[3:6, 0:3] c = Skew.vee(mC) / m I = M[3:6, 3:6] - (mC @ mC.T / m) return m, c, I ================================================ FILE: src/jaxsim/math/joint_model.py ================================================ from __future__ import annotations import jax import jax.numpy as jnp import jax_dataclasses import jaxlie from jax_dataclasses import Static import jaxsim.typing as jtp from jaxsim.math import Rotation from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass @jax_dataclasses.pytree_dataclass class JointModel(JaxsimDataclass): """ Class describing the joint kinematics of a robot model. Attributes: λ_H_pre: The homogeneous transformation between the parent link and the predecessor frame of each joint. suc_H_i: The homogeneous transformation between the successor frame and the child link of each joint. joint_dofs: The number of DoFs of each joint. joint_names: The names of each joint. joint_types: The types of each joint. Note: Due to the presence of the static attributes, this class needs to be created already in a vectorized form. In other words, it cannot be created using vmap. """ λ_H_pre: jtp.Array suc_H_i: jtp.Array joint_dofs: Static[tuple[int, ...]] joint_names: Static[tuple[str, ...]] joint_types: Static[tuple[int, ...]] joint_axis: Static[tuple[JointGenericAxis, ...]] @staticmethod def build(description: ModelDescription) -> JointModel: """ Build the joint model of a model description. Args: description: The model description to consider. Returns: The joint model of the considered model description. """ # The link index is equal to its body index: [0, number_of_bodies - 1]. ordered_links = sorted( list(description.links_dict.values()), key=lambda l: l.index, ) # Note: the joint index is equal to its child link index, therefore it # starts from 1. 
ordered_joints = sorted( list(description.joints_dict.values()), key=lambda j: j.index, ) # Allocate the parent-to-predecessor and successor-to-child transforms. λ_H_pre = jnp.zeros(shape=(1 + len(ordered_joints), 4, 4), dtype=float) suc_H_i = jnp.zeros(shape=(1 + len(ordered_joints), 4, 4), dtype=float) # Initialize an identical parent-to-predecessor transform for the joint # between the world frame W and the base link B. λ_H_pre = λ_H_pre.at[0].set(jnp.eye(4)) # Initialize the successor-to-child transform of the joint between the # world frame W and the base link B. # We store here the optional transform between the root frame of the model # and the base link frame (this is needed only if the pose of the link frame # w.r.t. the implicit __model__ SDF frame is not the identity). suc_H_i = suc_H_i.at[0].set(ordered_links[0].pose) # Create the object to compute forward kinematics. fk = KinematicGraphTransforms(graph=description) # Compute the parent-to-predecessor and successor-to-child transforms for # each joint belonging to the model. # Note that the joint indices starts from i=1 given our joint model, # therefore the entries at index 0 are not updated. for joint in ordered_joints: λ_H_pre = λ_H_pre.at[joint.index].set( fk.relative_transform(relative_to=joint.parent.name, name=joint.name) ) suc_H_i = suc_H_i.at[joint.index].set( fk.relative_transform(relative_to=joint.name, name=joint.child.name) ) # Define the DoFs of the base link. base_dofs = 0 if description.fixed_base else 6 # We always add a dummy fixed joint between world and base. # TODO: Port floating-base support also at this level, not only in RBDAs. 
return JointModel( λ_H_pre=λ_H_pre, suc_H_i=suc_H_i, # Static attributes joint_dofs=tuple([base_dofs] + [1 for _ in ordered_joints]), joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]), joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]), joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints), ) def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix: r""" Return the homogeneous transformation between the parent link and the predecessor frame of a joint. Args: joint_index: The index of the joint. Returns: The homogeneous transformation :math:`{}^{\lambda(i)} \mathbf{H}_{\text{pre}(i)}`. """ return self.λ_H_pre[joint_index] def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix: r""" Return the homogeneous transformation between the successor frame and the child link of a joint. Args: joint_index: The index of the joint. Returns: The homogeneous transformation :math:`{}^{\text{suc}(i)} \mathbf{H}_i`. """ return self.suc_H_i[joint_index] @jax.jit def supported_joint_motion( joint_types: jtp.Array, joint_positions: jtp.Matrix, joint_axes: jtp.Matrix ) -> jtp.Matrix: """ Compute the transforms of the joints. Args: joint_types: The types of the joints. joint_positions: The positions of the joints. joint_axes: The axes of the joints. Returns: The transforms of the joints. """ # Prepare the joint position s = jnp.array(joint_positions).astype(float) def compute_F() -> tuple[jtp.Matrix, jtp.Array]: return jaxlie.SE3.identity() def compute_R() -> tuple[jtp.Matrix, jtp.Array]: # Get the additional argument specifying the joint axis. # This is a metadata required by only some joint types. axis = jnp.array(joint_axes).astype(float).squeeze() pre_H_suc = jaxlie.SE3.from_matrix( matrix=jnp.eye(4).at[:3, :3].set(Rotation.from_axis_angle(vector=s * axis)) ) return pre_H_suc def compute_P() -> tuple[jtp.Matrix, jtp.Array]: # Get the additional argument specifying the joint axis. 
# This is a metadata required by only some joint types. axis = jnp.array(joint_axes).astype(float).squeeze() pre_H_suc = jaxlie.SE3.from_rotation_and_translation( rotation=jaxlie.SO3.identity(), translation=jnp.array(s * axis), ) return pre_H_suc return jax.lax.switch( index=joint_types, branches=( compute_F, # JointType.Fixed compute_R, # JointType.Revolute compute_P, # JointType.Prismatic ), ).as_matrix() ================================================ FILE: src/jaxsim/math/quaternion.py ================================================ import jax.lax import jax.numpy as jnp import jaxlie import jaxsim.typing as jtp from .utils import safe_norm class Quaternion: """ A utility class for quaternion operations. """ @staticmethod def to_xyzw(wxyz: jtp.Vector) -> jtp.Vector: """ Convert a quaternion from WXYZ to XYZW representation. Args: wxyz: Quaternion in WXYZ representation. Returns: Quaternion in XYZW representation. """ return wxyz.squeeze()[jnp.array([1, 2, 3, 0])] @staticmethod def to_wxyz(xyzw: jtp.Vector) -> jtp.Vector: """ Convert a quaternion from XYZW to WXYZ representation. Args: xyzw: Quaternion in XYZW representation. Returns: Quaternion in WXYZ representation. """ return xyzw.squeeze()[jnp.array([3, 0, 1, 2])] @staticmethod def to_dcm(quaternion: jtp.Vector) -> jtp.Matrix: """ Convert a quaternion to a direction cosine matrix (DCM). Args: quaternion: Quaternion in XYZW representation. Returns: The Direction cosine matrix (DCM). """ return jaxlie.SO3(wxyz=quaternion).as_matrix() @staticmethod def from_dcm(dcm: jtp.Matrix) -> jtp.Vector: """ Convert a direction cosine matrix (DCM) to a quaternion. Args: dcm: Direction cosine matrix (DCM). Returns: Quaternion in WXYZ representation. """ return jaxlie.SO3.from_matrix(matrix=dcm).wxyz @staticmethod def derivative( quaternion: jtp.Vector, omega: jtp.Vector, omega_in_body_fixed: bool = False, K: float = 0.1, ) -> jtp.Vector: """ Compute the derivative of a quaternion given angular velocity. 
Args: quaternion: Quaternion in XYZW representation. omega: Angular velocity vector. omega_in_body_fixed (bool): Whether the angular velocity is in the body-fixed frame. K (float): A scaling factor. Returns: The derivative of the quaternion. """ ω = omega.squeeze() quaternion = quaternion.squeeze() def Q_body(q: jtp.Vector) -> jtp.Matrix: qw, qx, qy, qz = q return jnp.array( [ [qw, -qx, -qy, -qz], [qx, qw, -qz, qy], [qy, qz, qw, -qx], [qz, -qy, qx, qw], ] ) def Q_inertial(q: jtp.Vector) -> jtp.Matrix: qw, qx, qy, qz = q return jnp.array( [ [qw, -qx, -qy, -qz], [qx, qw, qz, -qy], [qy, -qz, qw, qx], [qz, qy, -qx, qw], ] ) Q = jax.lax.cond( pred=omega_in_body_fixed, true_fun=Q_body, false_fun=Q_inertial, operand=quaternion, ) norm_ω = safe_norm(ω) qd = 0.5 * ( Q @ jnp.hstack( [ K * norm_ω * (1 - safe_norm(quaternion)), ω, ] ) ) return jnp.vstack(qd) @staticmethod def integration( quaternion: jtp.VectorLike, dt: jtp.FloatLike, omega: jtp.VectorLike, omega_in_body_fixed: jtp.BoolLike = False, ) -> jtp.Vector: """ Integrate a quaternion in SO(3) given an angular velocity. Args: quaternion: The quaternion to integrate. dt: The time step. omega: The angular velocity vector. omega_in_body_fixed: Whether the angular velocity is in body-fixed representation as opposed to the default inertial-fixed representation. Returns: The integrated quaternion. """ ω_AB = jnp.array(omega).squeeze().astype(float) A_Q_B = jnp.array(quaternion).squeeze().astype(float) # Build the initial SO(3) quaternion. W_Q_B_t0 = jaxlie.SO3(wxyz=A_Q_B) # Integrate the quaternion on the manifold. 
W_Q_B_tf = jax.lax.select( pred=omega_in_body_fixed, on_true=(W_Q_B_t0 @ jaxlie.SO3.exp(tangent=dt * ω_AB)).wxyz, on_false=(jaxlie.SO3.exp(tangent=dt * ω_AB) @ W_Q_B_t0).wxyz, ) return W_Q_B_tf ================================================ FILE: src/jaxsim/math/rotation.py ================================================ import jax.numpy as jnp import jaxlie import jaxsim.typing as jtp from .skew import Skew from .utils import safe_norm class Rotation: """ A utility class for rotation matrix operations. """ @staticmethod def x(theta: jtp.Float) -> jtp.Matrix: """ Generate a 3D rotation matrix around the X-axis. Args: theta: Rotation angle in radians. Returns: The 3D rotation matrix. """ return jaxlie.SO3.from_x_radians(theta=theta).as_matrix() @staticmethod def y(theta: jtp.Float) -> jtp.Matrix: """ Generate a 3D rotation matrix around the Y-axis. Args: theta: Rotation angle in radians. Returns: The 3D rotation matrix. """ return jaxlie.SO3.from_y_radians(theta=theta).as_matrix() @staticmethod def z(theta: jtp.Float) -> jtp.Matrix: """ Generate a 3D rotation matrix around the Z-axis. Args: theta: Rotation angle in radians. Returns: The 3D rotation matrix. """ return jaxlie.SO3.from_z_radians(theta=theta).as_matrix() @staticmethod def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix: """ Generate a 3D rotation matrix from an axis-angle representation. Args: vector: Axis-angle representation or the rotation as a 3D vector. Returns: The SO(3) rotation matrix. """ vector = vector.squeeze() theta = safe_norm(vector) s = jnp.sin(theta) c = jnp.cos(theta) c1 = 2 * jnp.sin(theta / 2.0) ** 2 safe_theta = jnp.where(theta == 0, 1.0, theta) u = vector / safe_theta u = jnp.vstack(u.squeeze()) R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T return R.transpose() @staticmethod def log_vee(R: jnp.ndarray) -> jtp.Vector: """ Compute the logarithm map of an SO(3) rotation matrix. Args: R: The SO(3) rotation matrix. Returns: The corresponding so(3) tangent vector. 
""" return jaxlie.SO3.from_matrix(R).log() ================================================ FILE: src/jaxsim/math/skew.py ================================================ import jax.numpy as jnp import jaxsim.typing as jtp class Skew: """ A utility class for skew-symmetric matrix operations. """ @staticmethod def wedge(vector: jtp.Vector) -> jtp.Matrix: """ Compute the skew-symmetric matrix (wedge operator) of a 3D vector. Args: vector: A 3D vector. Returns: The skew-symmetric matrix corresponding to the input vector. """ vector = vector.reshape(-1, 3) x, y, z = jnp.split(vector, 3, axis=-1) skew = jnp.stack( [ jnp.concatenate([jnp.zeros_like(x), -z, y], axis=-1), jnp.concatenate([z, jnp.zeros_like(x), -x], axis=-1), jnp.concatenate([-y, x, jnp.zeros_like(x)], axis=-1), ], axis=-2, ).squeeze() return skew @staticmethod def vee(matrix: jtp.Matrix) -> jtp.Vector: """ Extract the 3D vector from a skew-symmetric matrix (vee operator). Args: matrix: A 3x3 skew-symmetric matrix. Returns: The 3D vector extracted from the input matrix. """ vector = 0.5 * jnp.vstack( [ matrix[2, 1] - matrix[1, 2], matrix[0, 2] - matrix[2, 0], matrix[1, 0] - matrix[0, 1], ] ) return vector ================================================ FILE: src/jaxsim/math/transform.py ================================================ import jax.numpy as jnp import jaxlie import jaxsim.typing as jtp class Transform: """ A utility class for transformation matrix operations. """ @staticmethod def from_quaternion_and_translation( quaternion: jtp.VectorLike | None = None, translation: jtp.VectorLike | None = None, inverse: jtp.BoolLike = False, normalize_quaternion: jtp.BoolLike = False, ) -> jtp.Matrix: """ Create a transformation matrix from a quaternion and a translation. Args: quaternion: A 4D vector representing a SO(3) orientation. translation: A 3D vector representing a translation. inverse: Whether to compute the inverse transformation. 
normalize_quaternion: Whether to normalize the quaternion before creating the transformation. Returns: The 4x4 transformation matrix representing the SE(3) transformation. """ quaternion = quaternion if quaternion is not None else jnp.array([1.0, 0, 0, 0]) translation = translation if translation is not None else jnp.zeros(3) W_Q_B = jnp.array(quaternion).astype(float) W_p_B = jnp.array(translation).astype(float) assert W_p_B.shape[-1] == 3 assert W_Q_B.shape[-1] == 4 A_R_B = jaxlie.SO3(wxyz=W_Q_B) A_R_B = A_R_B if not normalize_quaternion else A_R_B.normalize() A_H_B = jaxlie.SE3.from_rotation_and_translation( rotation=A_R_B, translation=W_p_B ) return A_H_B.as_matrix() if not inverse else A_H_B.inverse().as_matrix() @staticmethod def from_rotation_and_translation( rotation: jtp.MatrixLike | None = None, translation: jtp.VectorLike | None = None, inverse: jtp.BoolLike = False, ) -> jtp.Matrix: """ Create a transformation matrix from a rotation matrix and a translation vector. Args: rotation: A 3x3 rotation matrix representing a SO(3) orientation. translation: A 3D vector representing a translation. inverse: Whether to compute the inverse transformation. Returns: The 4x4 transformation matrix representing the SE(3) transformation. """ rotation = rotation if rotation is not None else jnp.eye(3) translation = translation if translation is not None else jnp.zeros(3) A_R_B = jnp.array(rotation).astype(float) W_p_B = jnp.array(translation).astype(float) assert W_p_B.size == 3 assert A_R_B.shape == (3, 3) A_H_B = jaxlie.SE3.from_rotation_and_translation( rotation=jaxlie.SO3.from_matrix(A_R_B), translation=W_p_B ) return A_H_B.as_matrix() if not inverse else A_H_B.inverse().as_matrix() @staticmethod def inverse(transform: jtp.MatrixLike) -> jtp.Matrix: """ Compute the inverse transformation matrix. Args: transform: A 4x4 transformation matrix. Returns: The 4x4 inverse transformation matrix. 
""" return jaxlie.SE3.from_matrix(matrix=transform).inverse().as_matrix() ================================================ FILE: src/jaxsim/math/utils.py ================================================ import jax import jax.numpy as jnp import jaxsim.typing as jtp def _make_safe_norm(axis, keepdims): @jax.custom_jvp def _safe_norm(array: jtp.ArrayLike) -> jtp.Array: """ Compute an array norm handling NaNs and making sure that it is safe to get the gradient. Args: array: The array for which to compute the norm. Returns: The norm of the array with handling for zero arrays to avoid NaNs. """ # Compute the norm of the array along the specified axis. return jnp.linalg.norm(array, axis=axis, keepdims=keepdims) @_safe_norm.defjvp def _safe_norm_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents # Check if the entire array is composed of zeros. is_zero = jnp.allclose(x, 0) # Replace zeros with an array of ones temporarily to avoid division by zero. # This ensures the computation of norm does not produce NaNs or Infs. array = jnp.where(is_zero, jnp.ones_like(x), x) # Compute the norm of the array along the specified axis. norm = jnp.linalg.norm(array, axis=axis, keepdims=keepdims) dot = jnp.sum(array * x_dot, axis=axis, keepdims=keepdims) tangent = jnp.where(is_zero, 0.0, dot / norm) return jnp.where(is_zero, 0.0, norm), tangent return _safe_norm def safe_norm(array: jtp.ArrayLike, *, axis=None, keepdims: bool = False) -> jtp.Array: """ Compute an array norm handling NaNs and making sure that it is safe to get the gradient. Args: array: The array for which to compute the norm. axis: The axis for which to compute the norm. keepdims: Whether to keep the dimensions of the input Returns: The norm of the array with handling for zero arrays to avoid NaNs. 
""" return _make_safe_norm(axis, keepdims)(array) ================================================ FILE: src/jaxsim/mujoco/__init__.py ================================================ from .loaders import ModelToMjcf, RodModelToMjcf, SdfToMjcf, UrdfToMjcf from .model import MujocoModelHelper from .utils import MujocoCamera, mujoco_data_from_jaxsim from .visualizer import MujocoVideoRecorder, MujocoVisualizer ================================================ FILE: src/jaxsim/mujoco/__main__.py ================================================ import argparse import pathlib import sys import time import numpy as np from . import ModelToMjcf, MujocoModelHelper, MujocoVisualizer if __name__ == "__main__": parser = argparse.ArgumentParser( prog="jaxsim.mujoco", description="Process URDF and SDF files for Mujoco usage.", ) parser.add_argument( "-d", "--description", required=True, metavar="INPUT_FILE", type=pathlib.Path, help="Path to the URDF or SDF file.", ) parser.add_argument( "-m", "--model-name", metavar="NAME", type=str, default=None, help="The target model of a SDF description if multiple models exists.", ) parser.add_argument( "-e", "--export", metavar="MJCF_FILE", type=pathlib.Path, default=None, help="Path to the exported MJCF file.", ) parser.add_argument( "-f", "--force", action="store_true", default=False, help="Override the output MJCF file if it already exists (default: %(default)s).", ) parser.add_argument( "-p", "--print", action="store_true", default=False, help="Print in the stdout the exported MJCF string (default: %(default)s).", ) parser.add_argument( "-v", "--visualize", action="store_true", default=False, help="Visualize the description in the Mujoco viewer (default: %(default)s).", ) parser.add_argument( "-b", "--base-position", metavar=("x", "y", "z"), nargs=3, type=float, default=None, help="Override the base position (supports only floating-base models).", ) parser.add_argument( "-q", "--base-quaternion", metavar=("w", "x", "y", "z"), nargs=4, 
type=float, default=None, help="Override the base quaternion (supports only floating-base models).", ) args = parser.parse_args() # ================== # Validate arguments # ================== # Expand the path of the URDF/SDF file if not absolute. if args.description is not None: args.description = ( ( args.description if args.description.is_absolute() else pathlib.Path.cwd() / args.description ) .expanduser() .absolute() ) if not pathlib.Path(args.description).is_file(): msg = f"The URDF/SDF file '{args.description}' does not exist." parser.error(msg) sys.exit(1) # Expand the path of the output MJCF file if not absolute. if args.export is not None: args.export = ( ( args.export if args.export.is_absolute() else pathlib.Path.cwd() / args.export ) .expanduser() .absolute() ) if pathlib.Path(args.export).is_file() and not args.force: msg = "The output file '{}' already exists, use '--force' to override." parser.error(msg.format(args.export)) sys.exit(1) # ================================================ # Load the URDF/SDF file and produce a MJCF string # ================================================ mjcf_string, assets = ModelToMjcf.convert(args.description) if args.print: print(mjcf_string, flush=True) # ======================================== # Write the MJCF string to the output file # ======================================== if args.export is not None: with open(args.export, "w+", encoding="utf-8") as file: file.write(mjcf_string) # ======================================= # Visualize the MJCF in the Mujoco viewer # ======================================= if args.visualize: mj_model_helper = MujocoModelHelper.build_from_xml( mjcf_description=mjcf_string, assets=assets ) viz = MujocoVisualizer(model=mj_model_helper.model, data=mj_model_helper.data) with viz.open() as viewer: with viewer.lock(): if args.base_position is not None: mj_model_helper.set_base_position( position=np.array(args.base_position) ) if args.base_quaternion is not None: 
mj_model_helper.set_base_orientation( orientation=np.array(args.base_quaternion) ) viz.sync(viewer=viewer) while viewer.is_running(): time.sleep(0.500) # ============================= # Exit the program with success # ============================= sys.exit(0) ================================================ FILE: src/jaxsim/mujoco/loaders.py ================================================ import pathlib import tempfile import warnings from collections.abc import Sequence from typing import Any import jaxlie import mujoco as mj import numpy as np import rod.urdf.exporter from lxml import etree as ET from jaxsim import logging from .utils import MujocoCamera MujocoCameraType = ( MujocoCamera | Sequence[MujocoCamera] | dict[str, str] | Sequence[dict[str, str]] ) def load_rod_model( model_description: str | pathlib.Path | rod.Model, is_urdf: bool | None = None, model_name: str | None = None, ) -> rod.Model: """ Load a ROD model from a URDF/SDF file or a ROD model. Args: model_description: The URDF/SDF file or ROD model to load. is_urdf: Whether to force parsing the model description as a URDF file. model_name: The name of the model to load from the resource. Returns: rod.Model: The loaded ROD model. """ # Parse the SDF resource. sdf_element = rod.Sdf.load(sdf=model_description, is_urdf=is_urdf) # Fail if the SDF resource has no model. if len(sdf_element.models()) == 0: raise RuntimeError("Failed to find any model in the model description") # Return the model if there is only one. if len(sdf_element.models()) == 1: if model_name is not None and sdf_element.models()[0].name != model_name: raise ValueError(f"Model '{model_name}' not found in the description") return sdf_element.models()[0] # Require users to specify the model name if there are multiple models. if model_name is None: msg = "The resource has multiple models. Please specify the model name." raise ValueError(msg) # Build a dictionary of models in the resource for easy access. 
models = {m.name: m for m in sdf_element.models()} if model_name not in models: raise ValueError(f"Model '{model_name}' not found in the resource") return models[model_name] class ModelToMjcf: """ Class to convert a URDF/SDF file or a ROD model to a Mujoco MJCF string. """ @staticmethod def convert( model: str | pathlib.Path | rod.Model, considered_joints: list[str] | None = None, plane_normal: tuple[float, float, float] = (0, 0, 1), heightmap: bool | None = None, heightmap_samples_xy: tuple[int, int] = (101, 101), cameras: MujocoCameraType = (), ) -> tuple[str, dict[str, Any]]: """ Convert a model to a Mujoco MJCF string. Args: model: The URDF/SDF file or ROD model to convert. considered_joints: The list of joint names to consider in the conversion. plane_normal: The normal vector of the plane. heightmap: Whether to generate a heightmap. heightmap_samples_xy: The number of points in the heightmap grid. cameras: The custom cameras to add to the scene. Returns: A tuple containing the MJCF string and the dictionary of assets. """ match model: case rod.Model(): rod_model = model case str() | pathlib.Path(): # Convert the JaxSim model to a ROD model. rod_model = load_rod_model( model_description=model, is_urdf=None, model_name=None, ) case _: raise TypeError(f"Unsupported type for 'model': {type(model)}") # Convert the ROD model to MJCF. return RodModelToMjcf.convert( rod_model=rod_model, considered_joints=considered_joints, plane_normal=plane_normal, heightmap=heightmap, heightmap_samples_xy=heightmap_samples_xy, cameras=cameras, ) class RodModelToMjcf: """ Class to convert a ROD model to a Mujoco MJCF string. """ @staticmethod def assets_from_rod_model( rod_model: rod.Model, ) -> dict[str, bytes]: """ Generate a dictionary of assets from a ROD model. Args: rod_model: The ROD model to extract the assets from. Returns: dict: A dictionary of assets. 
""" import resolve_robotics_uri_py assets_files = dict() for link in rod_model.links(): for visual in link.visuals(): if visual.geometry.mesh and visual.geometry.mesh.uri: assets_files[visual.geometry.mesh.uri] = ( resolve_robotics_uri_py.resolve_robotics_uri( visual.geometry.mesh.uri ) ) for collision in link.collisions(): if collision.geometry.mesh and collision.geometry.mesh.uri: assets_files[collision.geometry.mesh.uri] = ( resolve_robotics_uri_py.resolve_robotics_uri( collision.geometry.mesh.uri ) ) assets = { asset_name: asset.read_bytes() for asset_name, asset in assets_files.items() } return assets @staticmethod def add_floating_joint( urdf_string: str, base_link_name: str, floating_joint_name: str = "world_to_base", ) -> str: """ Add a floating joint to a URDF string. Args: urdf_string: The URDF string to modify. base_link_name: The name of the base link to attach the floating joint. floating_joint_name: The name of the floating joint to add. Returns: str: The modified URDF string. """ with tempfile.NamedTemporaryFile(mode="w+", suffix=".urdf") as urdf_file: # Write the URDF string to a temporary file and move current position # to the beginning. urdf_file.write(urdf_string) urdf_file.seek(0) # Parse the MJCF string as XML (etree). parser = ET.XMLParser(remove_blank_text=True) tree = ET.parse(source=urdf_file, parser=parser) root: ET._Element = tree.getroot() if root.find(f".//joint[@name='{floating_joint_name}']") is not None: msg = f"The URDF already has a floating joint '{floating_joint_name}'" warnings.warn(msg, stacklevel=2) return ET.tostring(root, pretty_print=True).decode() # Create the "world" link if it doesn't exist. if root.find(".//link[@name='world']") is None: _ = ET.SubElement(root, "link", name="world") # Create the floating joint. world_to_base = ET.SubElement( root, "joint", name=floating_joint_name, type="floating" ) # Check that the base link exists. 
if root.find(f".//link[@name='{base_link_name}']") is None: raise ValueError(f"Link '{base_link_name}' not found in the URDF") # Attach the floating joint to the base link. ET.SubElement(world_to_base, "parent", link="world") ET.SubElement(world_to_base, "child", link=base_link_name) urdf_string = ET.tostring(root, pretty_print=True).decode() return urdf_string @staticmethod def convert( rod_model: rod.Model, considered_joints: list[str] | None = None, plane_normal: tuple[float, float, float] = (0, 0, 1), heightmap: bool | None = None, heightmap_samples_xy: tuple[int, int] = (101, 101), cameras: MujocoCameraType = (), ) -> tuple[str, dict[str, Any]]: """ Convert a ROD model to a Mujoco MJCF string. Args: rod_model: The ROD model to convert. considered_joints: The list of joint names to consider in the conversion. plane_normal: The normal vector of the plane. heightmap: Whether to generate a heightmap. heightmap_samples_xy: The number of points in the heightmap grid. cameras: The custom cameras to add to the scene. Returns: A tuple containing the MJCF string and the dictionary of assets. """ # ------------------------------------- # Convert the model description to URDF # ------------------------------------- # Consider all joints if not specified otherwise. considered_joints = set( considered_joints if considered_joints is not None else [j.name for j in rod_model.joints() if j.type != "fixed"] ) # If considered joints are passed, make sure that they are all part of the model. if considered_joints - {j.name for j in rod_model.joints()}: extra_joints = considered_joints - {j.name for j in rod_model.joints()} msg = f"Couldn't find the following joints in the model: '{extra_joints}'" raise ValueError(msg) # Create a dictionary of joints for quick access. joints_dict = {j.name: j for j in rod_model.joints()} # Convert all the joints not considered to fixed joints. 
for joint_name in {j.name for j in rod_model.joints()} - considered_joints: joints_dict[joint_name].type = "fixed" # Convert the ROD model to URDF. urdf_string = rod.urdf.exporter.UrdfExporter( gazebo_preserve_fixed_joints=False, pretty=True ).to_urdf_string( sdf=rod.Sdf(model=rod_model, version="1.7"), ) # ------------------------------------- # Add a floating joint if floating-base # ------------------------------------- base_link_name = rod_model.get_canonical_link() if not rod_model.is_fixed_base(): considered_joints |= {"world_to_base"} urdf_string = RodModelToMjcf.add_floating_joint( urdf_string=urdf_string, base_link_name=base_link_name, floating_joint_name="world_to_base", ) # --------------------------------------- # Inject the element in the URDF # --------------------------------------- parser = ET.XMLParser(remove_blank_text=True) root = ET.fromstring(text=urdf_string.encode(), parser=parser) mujoco_element = ( ET.SubElement(root, "mujoco") if len(root.findall("./mujoco")) == 0 else root.find("./mujoco") ) _ = ET.SubElement( mujoco_element, "compiler", balanceinertia="true", discardvisual="false", ) urdf_string = ET.tostring(root, pretty_print=True).decode() # ------------------------------ # Post-process all dummy visuals # ------------------------------ parser = ET.XMLParser(remove_blank_text=True) root: ET._Element = ET.fromstring(text=urdf_string.encode(), parser=parser) # Give a tiny radius to all dummy spheres for geometry in root.findall(".//visual/geometry[sphere]"): radius = np.fromstring( geometry.find("./sphere").attrib["radius"], sep=" ", dtype=float ) if np.allclose(radius, np.zeros(1)): geometry.find("./sphere").set("radius", "0.001") # Give a tiny volume to all dummy boxes for geometry in root.findall(".//visual/geometry[box]"): size = np.fromstring( geometry.find("./box").attrib["size"], sep=" ", dtype=float ) if np.allclose(size, np.zeros(3)): geometry.find("./box").set("size", "0.001 0.001 0.001") urdf_string = ET.tostring(root, 
pretty_print=True).decode() # ------------------------ # Convert the URDF to MJCF # ------------------------ # Load the URDF model into Mujoco. assets = RodModelToMjcf.assets_from_rod_model(rod_model=rod_model) mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets) # Get the joint names. mj_joint_names = { mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx) for idx in range(mj_model.njnt) } # Check that the Mujoco model only has the considered joints. if mj_joint_names != considered_joints: extra1 = mj_joint_names - considered_joints extra2 = considered_joints - mj_joint_names extra_joints = extra1.union(extra2) msg = "The Mujoco model has the following extra/missing joints: '{}'" raise ValueError(msg.format(extra_joints)) # Windows locks open files, so we use mkstemp() to create a temporary file without keeping it open. with tempfile.NamedTemporaryFile( suffix=".xml", prefix=f"{rod_model.name}_", delete=False ) as tmp: temp_filename = tmp.name try: # Convert the in-memory Mujoco model to MJCF. mj.mj_saveLastXML(temp_filename, mj_model) # Parse the MJCF file as XML. parser = ET.XMLParser(remove_blank_text=True) tree = ET.parse(source=temp_filename, parser=parser) finally: pathlib.Path(temp_filename).unlink() # Get the root element. root: ET._Element = tree.getroot() # Find the element (might be the root itself). 
mujoco_element: ET._Element = next(iter(root.iter("mujoco"))) # -------------- # Add the frames # -------------- for frame in rod_model.frames(): frame: rod.Frame parent_name = frame.attached_to parent_element = mujoco_element.find(f".//body[@name='{parent_name}']") if parent_element is None and parent_name == base_link_name: parent_element = mujoco_element.find(".//worldbody") if parent_element is not None: quat = jaxlie.SO3.from_rpy_radians(*frame.pose.rpy).wxyz _ = ET.SubElement( parent_element, "site", name=frame.name, pos=" ".join(map(str, frame.pose.xyz)), quat=" ".join(map(str, quat)), ) else: logging.debug(f"Parent link '{parent_name}' not found") # -------------- # Add the motors # -------------- if len(mujoco_element.findall(".//actuator")) > 0: raise RuntimeError("The model already has elements.") # Add the actuator element. actuator_element = ET.SubElement(mujoco_element, "actuator") # Add a motor for each joint. for joint_element in mujoco_element.findall(".//joint"): assert ( joint_element.attrib["name"] in considered_joints ), joint_element.attrib["name"] if joint_element.attrib.get("type", "hinge") in {"free", "ball"}: continue ET.SubElement( actuator_element, "motor", name=f"{joint_element.attrib['name']}_motor", joint=joint_element.attrib["name"], gear="1", ) # --------------------------------------------- # Set full transparency of collision geometries # --------------------------------------------- parser = ET.XMLParser(remove_blank_text=True) # Get all the (optional) names of the URDF collision elements collision_names = { c.attrib["name"] for c in ET.fromstring(text=urdf_string.encode(), parser=parser).findall( ".//collision[geometry]" ) if "name" in c.attrib } # Set alpha=0 to the color of all collision elements for geometry_element in mujoco_element.findall(".//geom[@rgba]"): if geometry_element.attrib.get("name") in collision_names: r, g, b, _ = geometry_element.attrib["rgba"].split(" ") geometry_element.set("rgba", f"{r} {g} {b} 0") # 
----------------------- # Create the scene assets # ----------------------- asset_element = ( ET.SubElement(mujoco_element, "asset") if len(mujoco_element.findall(".//asset")) == 0 else mujoco_element.find(".//asset") ) _ = ET.SubElement( asset_element, "texture", type="skybox", builtin="gradient", rgb1="0.3 0.5 0.7", rgb2="0 0 0", width="512", height="512", ) _ = ET.SubElement( asset_element, "texture", name="plane_texture", type="2d", builtin="checker", rgb1="0.1 0.2 0.3", rgb2="0.2 0.3 0.4", width="512", height="512", mark="cross", markrgb=".8 .8 .8", ) _ = ET.SubElement( asset_element, "material", name="plane_material", texture="plane_texture", reflectance="0.2", texrepeat="5 5", texuniform="true", ) _ = ( ET.SubElement( asset_element, "hfield", name="terrain", nrow=str(int(heightmap_samples_xy[0])), ncol=str(int(heightmap_samples_xy[1])), # The following 'size' is a placeholder, it is updated dynamically # when a hfield/heightmap is stored into MjData. size="1 1 1 1", ) if heightmap else None ) # ---------------------------------- # Populate the scene with the assets # ---------------------------------- worldbody_scene_element = ET.SubElement(mujoco_element, "worldbody") _ = ET.SubElement( worldbody_scene_element, "geom", name="floor", type="plane" if not heightmap else "hfield", size="0 0 0.05", material="plane_material", condim="3", contype="1", conaffinity="1", zaxis=" ".join(map(str, plane_normal)), **({"hfield": "terrain"} if heightmap else {}), ) _ = ET.SubElement( worldbody_scene_element, "light", name="sun", mode="fixed", directional="true", castshadow="true", pos="0 0 10", dir="0 0 -1", ) # ------------------------------------------------------- # Add a camera following the CoM of the worldbody element # ------------------------------------------------------- worldbody_element = None # Find the element of our model by searching the one that contains # all the considered joints. This is needed because there might be multiple # elements inside . 
for wb in mujoco_element.findall(".//worldbody"): if all( wb.find(f".//joint[@name='{j}']") is not None for j in considered_joints ): worldbody_element = wb break if worldbody_element is None: raise RuntimeError("Failed to find the element of the model") # Camera attached to the model # It can be manually copied from `python -m mujoco.viewer --mjcf=` _ = ET.SubElement( worldbody_element, "camera", name="track", mode="trackcom", pos="1.930 -2.279 0.556", xyaxes="0.771 0.637 0.000 -0.116 0.140 0.983", fovy="60", ) # Add user-defined camera. for camera in cameras if isinstance(cameras, Sequence) else [cameras]: mj_camera = ( camera if isinstance(camera, MujocoCamera) else MujocoCamera.build(**camera) ) _ = ET.SubElement(worldbody_element, "camera", mj_camera.asdict()) # ------------------------------------------------ # Add a light following the CoM of the first link # ------------------------------------------------ if not rod_model.is_fixed_base(): # Light attached to the model _ = ET.SubElement( worldbody_element, "light", name="light_model", mode="targetbodycom", target=worldbody_element.find(".//body").attrib["name"], directional="false", castshadow="true", pos="1 0 5", ) # -------------------------------- # Return the resulting MJCF string # -------------------------------- mjcf_string = ET.tostring(root, pretty_print=True).decode() return mjcf_string, assets class UrdfToMjcf: """ Class to convert a URDF file to a Mujoco MJCF string. """ @staticmethod def convert( urdf: str | pathlib.Path, considered_joints: list[str] | None = None, model_name: str | None = None, plane_normal: tuple[float, float, float] = (0, 0, 1), heightmap: bool | None = None, cameras: MujocoCameraType = (), ) -> tuple[str, dict[str, Any]]: """ Convert a URDF file to a Mujoco MJCF string. Args: urdf: The URDF file to convert. considered_joints: The list of joint names to consider in the conversion. model_name: The name of the model to convert. plane_normal: The normal vector of the plane. 
heightmap: Whether to generate a heightmap. cameras: The list of cameras to add to the scene. Returns: tuple: A tuple containing the MJCF string and the assets dictionary. """ logging.warning("This method is deprecated. Use 'ModelToMjcf.convert' instead.") # Get the ROD model. rod_model = load_rod_model( model_description=urdf, is_urdf=True, model_name=model_name, ) # Convert the ROD model to MJCF. return RodModelToMjcf.convert( rod_model=rod_model, considered_joints=considered_joints, plane_normal=plane_normal, heightmap=heightmap, cameras=cameras, ) class SdfToMjcf: """ Class to convert a SDF file to a Mujoco MJCF string. """ @staticmethod def convert( sdf: str | pathlib.Path, considered_joints: list[str] | None = None, model_name: str | None = None, plane_normal: tuple[float, float, float] = (0, 0, 1), heightmap: bool | None = None, cameras: MujocoCameraType = (), ) -> tuple[str, dict[str, Any]]: """ Convert a SDF file to a Mujoco MJCF string. Args: sdf: The SDF file to convert. considered_joints: The list of joint names to consider in the conversion. model_name: The name of the model to convert. plane_normal: The normal vector of the plane. heightmap: Whether to generate a heightmap. cameras: The list of cameras to add to the scene. Returns: tuple: A tuple containing the MJCF string and the assets dictionary. """ logging.warning("This method is deprecated. Use 'ModelToMjcf.convert' instead.") # Get the ROD model. rod_model = load_rod_model( model_description=sdf, is_urdf=False, model_name=model_name, ) # Convert the ROD model to MJCF. 
return RodModelToMjcf.convert( rod_model=rod_model, considered_joints=considered_joints, plane_normal=plane_normal, heightmap=heightmap, cameras=cameras, ) ================================================ FILE: src/jaxsim/mujoco/model.py ================================================ from __future__ import annotations import functools import pathlib from collections.abc import Callable, Sequence from typing import Any import mujoco as mj import numpy as np import numpy.typing as npt import xmltodict from scipy.spatial.transform import Rotation import jaxsim.typing as jtp HeightmapCallable = Callable[[jtp.FloatLike, jtp.FloatLike], jtp.FloatLike] class MujocoModelHelper: """ Helper class to create and interact with Mujoco models and data objects. """ def __init__(self, model: mj.MjModel, data: mj.MjData | None = None) -> None: """ Initialize the MujocoModelHelper object. Args: model: A Mujoco model object. data: A Mujoco data object. If None, a new one will be created. """ self.model = model self.data = data if data is not None else mj.MjData(self.model) # Populate the data with kinematics. mj.mj_forward(self.model, self.data) # Keep the cache of this method local to improve GC. self.mask_qpos = functools.cache(self._mask_qpos) @staticmethod def build_from_xml( mjcf_description: str | pathlib.Path, assets: dict[str, Any] | None = None, heightmap: HeightmapCallable | None = None, heightmap_name: str = "terrain", heightmap_radius_xy: tuple[float, float] = (1.0, 1.0), ) -> MujocoModelHelper: """ Build a Mujoco model from an MJCF description. Args: mjcf_description: A string containing the XML description of the Mujoco model or a path to a file containing the XML description. assets: An optional dictionary containing the assets of the model. heightmap: A function in two variables that returns the height of a terrain in the specified coordinate point. heightmap_name: The default name of the heightmap in the MJCF description to load the corresponding configuration. 
heightmap_radius_xy: The extension of the heightmap in the x-y surface corresponding to the plane over which the grid of the sampled heightmap is generated. Returns: A MujocoModelHelper object. """ # Read the XML description if it is a path to file. mjcf_description = ( mjcf_description.read_text() if isinstance(mjcf_description, pathlib.Path) else mjcf_description ) if heightmap is None: hfield = None else: mjcf_description_dict = xmltodict.parse(xml_input=mjcf_description) # Create a dictionary of all hfield configurations from the MJCF. hfields = mjcf_description_dict["mujoco"]["asset"].get("hfield", []) hfields = hfields if isinstance(hfields, list) else [hfields] hfields_dict = {hfield["@name"]: hfield for hfield in hfields} if heightmap_name not in hfields_dict: raise ValueError(f"Heightmap '{heightmap_name}' not found in MJCF") hfield_element = hfields_dict[heightmap_name] # Generate the hfield by sampling the heightmap function. hfield = generate_hfield( heightmap=heightmap, samples_xy=(int(hfield_element["@nrow"]), int(hfield_element["@ncol"])), radius_xy=heightmap_radius_xy, ) # Update dynamically the '/asset/hfield[@name=heightmap_name]@size' attribute # with the information of the sampled points. # This is necessary for correctly rendering the heightmap over the # specified xy area with the correct z elevation. size = [float(el) for el in hfield_element["@size"].split(" ")] size[0], size[1] = heightmap_radius_xy size[2] = 1.0 # The following could be zero but Mujoco complains if it's exactly zero. size[3] = max(0.000_001, -min(hfield)) # Replace the 'size' attribute. hfields_dict[heightmap_name]["@size"] = " ".join(str(el) for el in size) # Update the hfield elements of the original MJCF. # Only the hfield corresponding to 'heightmap_name' was actually edited. mjcf_description_dict["mujoco"]["asset"]["hfield"] = list( hfields_dict.values() ) # Serialize the updated MJCF to XML. 
mjcf_description = xmltodict.unparse( input_dict=mjcf_description_dict, pretty=True ) # Create the Mujoco model from the XML and, optionally, the dictionary of assets. model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets) data = mj.MjData(model) # Store the sampled heightmap into the Mujoco model. if heightmap is not None: assert hfield is not None model.hfield_data = hfield return MujocoModelHelper(model=model, data=data) def time(self) -> float: """Return the simulation time.""" return self.data.time def timestep(self) -> float: """Return the simulation timestep.""" return self.model.opt.timestep def gravity(self) -> npt.NDArray: """Return the 3D gravity vector.""" return np.array([0, 0, self.model.gravity]) # ========================= # Methods for the base link # ========================= def is_floating_base(self) -> bool: """Return true if the model is floating-base.""" # A body with no joints is considered a fixed-base model. # In fact, in mujoco, a floating-base model has a 6 DoFs first joint. if self.number_of_joints() == 0: return False # We just check that the first joint has 6 DoFs. 
joint0_type = self.model.jnt_type[0] return joint0_type == mj.mjtJoint.mjJNT_FREE def is_fixed_base(self) -> bool: """Return true if the model is fixed-base.""" return not self.is_floating_base() def base_link(self) -> str: """Return the name of the base link.""" return mj.mj_id2name( self.model, mj.mjtObj.mjOBJ_BODY, 0 if self.is_fixed_base() else 1 ) def base_position(self) -> npt.NDArray: """Return the 3D position of the base link.""" return ( self.data.qpos[:3] if self.is_floating_base() else self.body_position(body_name=self.base_link()) ) def base_orientation(self, dcm: bool = False) -> npt.NDArray: """Return the orientation of the base link.""" return ( ( np.reshape(self.data.xmat[0], newshape=(3, 3)) if dcm else self.data.xquat[0] ) if self.is_floating_base() else self.body_orientation(body_name=self.base_link(), dcm=dcm) ) def set_base_position(self, position: npt.NDArray) -> None: """Set the 3D position of the base link.""" if self.is_fixed_base(): raise ValueError("The position of a fixed-base model cannot be set.") position = np.atleast_1d(np.array(position).squeeze()) if position.size != 3: raise ValueError(f"Wrong position size ({position.size})") self.data.qpos[:3] = position def set_base_orientation(self, orientation: npt.NDArray, dcm: bool = False) -> None: """Set the 3D position of the base link.""" if self.is_fixed_base(): raise ValueError("The orientation of a fixed-base model cannot be set.") orientation = ( np.atleast_2d(np.array(orientation).squeeze()) if dcm else np.atleast_1d(np.array(orientation).squeeze()) ) if orientation.shape != ((4,) if not dcm else (3, 3)): raise ValueError(f"Wrong orientation shape {orientation.shape}") def is_quaternion(Q): return np.allclose(np.linalg.norm(Q), 1.0) def is_dcm(R): return np.allclose(np.linalg.det(R), 1.0) and np.allclose( R.T @ R, np.eye(3) ) if not (is_quaternion(orientation) if not dcm else is_dcm(orientation)): raise ValueError("The orientation is not a valid element of SO(3)") W_Q_B = ( 
Rotation.from_matrix(orientation).as_quat( canonical=True, scalar_first=False ) if dcm else orientation ) self.data.qpos[3:7] = W_Q_B # ================== # Methods for joints # ================== def number_of_joints(self) -> int: """Return the number of joints in the model.""" return self.model.njnt def number_of_dofs(self) -> int: """Return the number of DoFs in the model.""" return self.model.nq def joint_names(self) -> list[str]: """Return the names of the joints in the model.""" return [ mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, idx) for idx in range(0 if self.is_fixed_base() else 1, self.number_of_joints()) ] def joint_dofs(self, joint_name: str) -> int: """Return the number of DoFs of a joint.""" if joint_name not in self.joint_names(): raise ValueError(f"Joint '{joint_name}' not found") return self.data.joint(joint_name).qpos.size def joint_position(self, joint_name: str) -> npt.NDArray: """Return the position of a joint.""" if joint_name not in self.joint_names(): raise ValueError(f"Joint '{joint_name}' not found") return self.data.joint(joint_name).qpos def joint_positions(self, joint_names: list[str] | None = None) -> npt.NDArray: """Return the positions of the joints.""" joint_names = joint_names if joint_names is not None else self.joint_names() return np.hstack( [self.joint_position(joint_name) for joint_name in joint_names] ) def set_joint_position( self, joint_name: str, position: npt.NDArray | float ) -> None: """Set the position of a joint.""" position = np.atleast_1d(np.array(position).squeeze()) if position.size != self.joint_dofs(joint_name=joint_name): raise ValueError( f"Wrong position size ({position.size}) of " f"{self.joint_dofs(joint_name=joint_name)}-DoFs joint '{joint_name}'." 
) idx = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name) offset = self.model.jnt_qposadr[idx] sl = np.s_[offset : offset + self.joint_dofs(joint_name=joint_name)] self.data.qpos[sl] = position def set_joint_positions( self, joint_names: Sequence[str], positions: npt.NDArray | list[npt.NDArray] ) -> None: """Set the positions of multiple joints.""" mask = self.mask_qpos(joint_names=tuple(joint_names)) self.data.qpos[mask] = positions # ================== # Methods for bodies # ================== def number_of_bodies(self) -> int: """Return the number of bodies in the model.""" return self.model.nbody def body_names(self) -> list[str]: """Return the names of the bodies in the model.""" return [ mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, idx) for idx in range(self.number_of_bodies()) ] def body_position(self, body_name: str) -> npt.NDArray: """Return the position of a body.""" if body_name not in self.body_names(): raise ValueError(f"Body '{body_name}' not found") return self.data.body(body_name).xpos def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray: """Return the orientation of a body.""" if body_name not in self.body_names(): raise ValueError(f"Body '{body_name}' not found") return ( self.data.body(body_name).xmat if dcm else self.data.body(body_name).xquat ) # ====================== # Methods for geometries # ====================== def number_of_geometries(self) -> int: """Return the number of geometries in the model.""" return self.model.ngeom def geometry_names(self) -> list[str]: """Return the names of the geometries in the model.""" return [ mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_GEOM, idx) for idx in range(self.number_of_geometries()) ] def geometry_position(self, geometry_name: str) -> npt.NDArray: """Return the position of a geometry.""" if geometry_name not in self.geometry_names(): raise ValueError(f"Geometry '{geometry_name}' not found") return self.data.geom(geometry_name).xpos def geometry_orientation( 
self, geometry_name: str, dcm: bool = False ) -> npt.NDArray: """Return the orientation of a geometry.""" if geometry_name not in self.geometry_names(): raise ValueError(f"Geometry '{geometry_name}' not found") R = np.reshape(self.data.geom(geometry_name).xmat, newshape=(3, 3)) if dcm: return R q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True, scalar_first=False) return q_xyzw # =============== # Private methods # =============== def _mask_qpos(self, joint_names: tuple[str, ...]) -> npt.NDArray: """ Create a mask to access the DoFs of the desired `joint_names` in the `qpos` array. Args: joint_names: A tuple containing the names of the joints. Returns: A 1D array containing the indices of the `qpos` array to access the DoFs of the desired `joint_names`. Note: This method takes a tuple of strings because we cache the output mask for each combination of joint names. We need a hashable object for the cache. """ # Get the indices of the joints in `joint_names`. idxs = [ mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name) for joint_name in joint_names ] # We first get the index of each joint in the qpos array, and for those that # have multiple DoFs, we expand their mask by appending new elements. # Finally, we flatten the list of arrays to a single array, that is the # final qpos mask accessing all the DoFs of the desired `joint_names`. return np.atleast_1d( np.hstack( [ np.array( [ self.model.jnt_qposadr[idx] + i for i in range(self.joint_dofs(joint_name=joint_name)) ] ) for idx, joint_name in zip(idxs, joint_names, strict=True) ] ).squeeze() ) def generate_hfield( heightmap: HeightmapCallable, samples_xy: tuple[int, int] = (11, 11), radius_xy: tuple[float, float] = (1.0, 1.0), ) -> npt.NDArray: """ Generate an array with elevation points sampled from a heightmap function. The map will have the following format: ``` heightmap[0, 0] heightmap[0, 1] ... heightmap[0, size[1]-1] heightmap[1, 0] heightmap[1, 1] ... heightmap[1, size[1]-1] ... 
heightmap[size[0]-1, 0] heightmap[size[0]-1, 1] ... heightmap[size[0]-1, size[1]-1] ``` Args: heightmap: A function that takes two arguments (x, y) and returns the height at that point. samples_xy: A tuple of two integers representing the size of the grid. radius_xy: A tuple of two floats representing extension of the heightmap in the x-y surface corresponding to the area over which the grid of the sampled heightmap is generated. Returns: A flat array of the sampled terrain heightmap. """ # Generate the grid. x = np.linspace(-radius_xy[0], radius_xy[0], samples_xy[0]) y = np.linspace(-radius_xy[1], radius_xy[1], samples_xy[1]) # Generate the heightmap. return np.array([[heightmap(xi, yi) for xi in x] for yi in y]).flatten() ================================================ FILE: src/jaxsim/mujoco/utils.py ================================================ from __future__ import annotations import dataclasses from collections.abc import Sequence import mujoco as mj import numpy as np import numpy.typing as npt from scipy.spatial.transform import Rotation from .model import MujocoModelHelper def mujoco_data_from_jaxsim( mujoco_model: mj.MjModel, jaxsim_model, jaxsim_data, mujoco_data: mj.MjData | None = None, update_removed_joints: bool = True, ) -> mj.MjData: """ Create a Mujoco data object from a JaxSim model and data objects. Args: mujoco_model: The Mujoco model object corresponding to the JaxSim model. jaxsim_model: The JaxSim model object from which the Mujoco model was created. jaxsim_data: The JaxSim data object containing the state of the model. mujoco_data: An optional Mujoco data object. If None, a new one will be created. update_removed_joints: If True, the positions of the joints that have been removed during the model reduction process will be set to their initial values. Returns: The Mujoco data object containing the state of the JaxSim model. 
Note: This method is useful to initialize a Mujoco data object used for visualization with the state of a JaxSim model. In particular, this function takes care of initializing the positions of the joints that have been removed during the model reduction process. After the initial creation of the Mujoco data object, it's faster to update the state using an external MujocoModelHelper object. """ # The package `jaxsim.mujoco` is supposed to be jax-independent. # We import all the JaxSim resources privately. import jaxsim.api as js if not isinstance(jaxsim_model, js.model.JaxSimModel): raise ValueError("The `jaxsim_model` argument must be a JaxSimModel object.") if not isinstance(jaxsim_data, js.data.JaxSimModelData): raise ValueError("The `jaxsim_data` argument must be a JaxSimModelData object.") # Create the helper to operate on the Mujoco model and data. model_helper = MujocoModelHelper(model=mujoco_model, data=mujoco_data) # If the model is fixed-base, the Mujoco model won't have the joint corresponding # to the floating base, and the helper would raise an exception. if jaxsim_model.floating_base(): # Set the model position. model_helper.set_base_position(position=np.array(jaxsim_data.base_position)) # Set the model orientation. model_helper.set_base_orientation( orientation=np.array(jaxsim_data.base_orientation) ) # Set the joint positions. if jaxsim_model.dofs() > 0: model_helper.set_joint_positions( joint_names=list(jaxsim_model.joint_names()), positions=np.array(jaxsim_data.joint_positions), ) # Updating these joints is not necessary after the first time. # Users can disable this update after initialization. if update_removed_joints: # Create a dictionary with the joints that have been removed for various reasons # (like link lumping due to model reduction). joints_removed_dict = { j.name: j for j in jaxsim_model.description._joints_removed if j.name not in set(jaxsim_model.joint_names()) } # Set the positions of the removed joints. 
_ = [ model_helper.set_joint_position( position=joints_removed_dict[joint_name].initial_position, joint_name=joint_name, ) # Select all original joint that have been removed from the JaxSim model # that are still present in the Mujoco model. for joint_name in joints_removed_dict if joint_name in model_helper.joint_names() ] # Return the mujoco data with updated kinematics. mj.mj_forward(mujoco_model, model_helper.data) return model_helper.data @dataclasses.dataclass class MujocoCamera: """ Helper class storing parameters of a Mujoco camera. Refer to the official documentation for more details: https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-camera """ mode: str = "fixed" target: str | None = None fovy: str = "45" pos: str = "0 0 0" quat: str | None = None axisangle: str | None = None xyaxes: str | None = None zaxis: str | None = None euler: str | None = None name: str | None = None @classmethod def build(cls, **kwargs) -> MujocoCamera: """ Build a Mujoco camera from a dictionary. """ if not all(isinstance(value, str) for value in kwargs.values()): raise ValueError(f"Values must be strings: {kwargs}") return cls(**kwargs) @staticmethod def build_from_target_view( camera_name: str, mode: str = "fixed", lookat: Sequence[float | int] | npt.NDArray = (0, 0, 0), distance: float | int | npt.NDArray = 3, azimuth: float | int | npt.NDArray = 90, elevation: float | int | npt.NDArray = -45, fovy: float | int | npt.NDArray = 45, degrees: bool = True, **kwargs, ) -> MujocoCamera: """ Create a custom camera that looks at a target point. Note: The choice of the parameters is easier if we imagine to consider a target frame `T` whose origin is located over the lookat point and having the same orientation of the world frame `W`. We also introduce a camera frame `C` whose origin is located over the lower-left corner of the image, and having the x-axis pointing right and the y-axis pointing up in image coordinates. 
The camera renders what it sees in the -z direction of frame `C`. Args: camera_name: The name of the camera. mode: Camera positioning mode: - **"fixed"**: Fixed position and orientation relative to the body. - **"track"**: Fixed offset from the body in world coordinates, constant orientation. - **"trackcom"**: Like `"track"`, but relative to the center of mass of the subtree. - **"targetbody"**: Fixed position in body frame, oriented toward a target body. - **"targetbodycom"**: Like `"targetbody"`, but targets the subtree's center of mass. lookat: The target point to look at (origin of `T`). distance: The distance from the target point (displacement between the origins of `T` and `C`). azimuth: The rotation around z of the camera. With an angle of 0, the camera would loot at the target point towards the positive x-axis of `T`. elevation: The rotation around the x-axis of the camera frame `C`. Note that if you want to lift the view angle, the elevation is negative. fovy: The field of view of the camera. degrees: Whether the angles are in degrees or radians. **kwargs: Additional camera parameters. Returns: The custom camera. """ # Start from a frame whose origin is located over the lookat point. # We initialize a -90 degrees rotation around the z-axis because due to # the default camera coordinate system (x pointing right, y pointing up). W_H_C = np.eye(4) W_H_C[0:3, 3] = np.array(lookat) W_H_C[0:3, 0:3] = Rotation.from_euler( seq="ZX", angles=[-90, 90], degrees=True ).as_matrix() # Process the azimuth. R_az = Rotation.from_euler(seq="Y", angles=azimuth, degrees=degrees).as_matrix() W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_az # Process elevation. R_el = Rotation.from_euler( seq="X", angles=elevation, degrees=degrees ).as_matrix() W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_el # Process distance. tf_distance = np.eye(4) tf_distance[2, 3] = distance W_H_C = W_H_C @ tf_distance # Extract the position and the quaternion. 
p = W_H_C[0:3, 3] Q = Rotation.from_matrix(W_H_C[0:3, 0:3]).as_quat(scalar_first=True) return MujocoCamera.build( name=camera_name, mode=mode, fovy=str(fovy if degrees else np.rad2deg(fovy)), pos=" ".join(p.astype(str).tolist()), quat=" ".join(Q.astype(str).tolist()), **kwargs, ) def asdict(self) -> dict[str, str]: """ Convert the camera to a dictionary. """ return {k: v for k, v in dataclasses.asdict(self).items() if v is not None} ================================================ FILE: src/jaxsim/mujoco/visualizer.py ================================================ import contextlib import pathlib from collections.abc import Iterator, Sequence import mediapy as media import mujoco as mj import mujoco.viewer import numpy as np import numpy.typing as npt from scipy.spatial.transform import Rotation class MujocoVideoRecorder: """ Video recorder for the MuJoCo passive viewer. """ def __init__( self, model: list[mj.MjModel] | mj.MjModel, data: list[mj.MjData] | mj.MjData, fps: int = 30, width: int | None = None, height: int | None = None, **kwargs, ) -> None: """ Initialize the Mujoco video recorder. Args: model: The Mujoco model. data: The Mujoco data. fps: The frames per second. width: The width of the video. height: The height of the video. **kwargs: Additional arguments for the renderer. """ if isinstance(model, mj.MjModel): single_model = model elif isinstance(model, list) and len(model) == 1: single_model = model[0] else: raise ValueError( "Model must be a single instance of mj.MjModel or a list with at least one element." 
) width = width if width is not None else single_model.vis.global_.offwidth height = height if height is not None else single_model.vis.global_.offheight if single_model.vis.global_.offwidth != width: single_model.vis.global_.offwidth = width if single_model.vis.global_.offheight != height: single_model.vis.global_.offheight = height self.fps = fps self.frames: list[npt.NDArray] = [] self.data: list[mj.MjData] | mj.MjData | None = None self.model: list[mj.MjModel] | mj.MjModel | None = None self.reset(model=model, data=data) self.renderer = mujoco.Renderer( model=single_model, **(dict(width=width, height=height) | kwargs), ) def visualize_frame( self, frame_pose: list[float] | npt.NDArray | None = None ) -> None: """ Add visualization of a static frame. Args: frame_pose: The pose of a static frame to visualize as [x, y, z, roll, pitch, yaw]. """ scene = self.renderer.scene # Three free slots are needed for the axes (x, y, z). if scene.ngeom + 3 > scene.maxgeom: return # Read position and RPY orientation if not frame_pose: return try: x, y, z, roll, pitch, yaw = frame_pose except Exception as e: raise ValueError( "Frame pose elements must be a 6D list: 'x y z roll pitch yaw'" ) from e mat = Rotation.from_euler("xyz", [roll, pitch, yaw], degrees=False).as_matrix() origin = np.array([x, y, z]) length = 0.2 # length of axis cylinders radius = 0.01 # slim radius for cylinders for axis, color in zip( range(3), [(1, 0, 0, 1), (0, 1, 0, 1), (0, 0, 1, 1)], strict=True ): if scene.ngeom >= scene.maxgeom: break axis_dir = mat[:, axis] geom = scene.geoms[scene.ngeom] # Cylinder position is centered at origin + half length along axis pos = origin + axis_dir * length * 0.5 # Build rotation matrix for cylinder aligned with axis_dir # MuJoCo's cylinder local axis is along z-axis def rot_from_z(v: np.ndarray) -> np.ndarray: v = v / np.linalg.norm(v) z_axis = np.array([0, 0, 1]) if np.allclose(v, z_axis): return np.eye(3) if np.allclose(v, -z_axis): return np.diag([1, -1, -1]) cross 
= np.cross(z_axis, v) dot = np.dot(z_axis, v) skew = np.array( [ [0, -cross[2], cross[1]], [cross[2], 0, -cross[0]], [-cross[1], cross[0], 0], ] ) R = np.eye(3) + skew + skew @ skew * (1 / (1 + dot)) return R R = rot_from_z(axis_dir) mat_flat = R.flatten() mj.mjv_initGeom( geom=geom, type=mj.mjtGeom.mjGEOM_CYLINDER, # The `size` arguments takes three positional arguments. # In the cylinder case, the first two are the radius and half-length, # and the third is not used (set to 0.0). size=np.array([radius, length * 0.5, 0.0]), rgba=np.array(color), pos=pos, mat=mat_flat, ) geom.category = mj.mjtCatBit.mjCAT_STATIC scene.ngeom += 1 def reset( self, model: mj.MjModel | None = None, data: list[mj.MjData] | mj.MjData | None = None, ) -> None: """Reset the model and data.""" self.frames = [] self.data = data if data is not None else self.data self.data = self.data if isinstance(self.data, list) else [self.data] self.model = model if model is not None else self.model self.model = self.model if isinstance(self.model, list) else [self.model] assert len(self.data) == len(self.model) or len(self.model) == 1, ( f"Length mismatch: len(data)={len(self.data)}, len(model)={len(self.model)}. " "They must be equal or model must have length 1." ) def render_frame( self, camera_name: str = "track", frame_pose: list[float] | npt.NDArray | None = None, ) -> npt.NDArray: """ Render a frame. Args: camera_name: The name of the camera to use for rendering. frame_pose: The pose of a static frame to visualize as [x, y, z, roll, pitch, yaw]. Returns: The rendered frame as a NumPy array. """ for idx, data in enumerate(self.data): # Use a single model for rendering if multiple data instances are provided. # Otherwise, use the data index to select the corresponding model. 
model = self.model[0] if len(self.model) == 1 else self.model[idx] mj.mj_forward(model, data) if idx == 0: self.renderer.update_scene(data=data, camera=camera_name) self.visualize_frame(frame_pose=frame_pose) continue mujoco.mjv_addGeoms( m=model, d=data, opt=mj.MjvOption(), pert=mj.MjvPerturb(), catmask=mj.mjtCatBit.mjCAT_DYNAMIC, scn=self.renderer.scene, ) return self.renderer.render() def record_frame( self, camera_name: str = "track", frame_pose: list[float] | npt.NDArray | None = None, ) -> None: """Store a frame in the buffer.""" frame = self.render_frame(camera_name=camera_name, frame_pose=frame_pose) self.frames.append(frame) def write_video(self, path: pathlib.Path | str, exist_ok: bool = False) -> None: """Write the video to a file.""" # Resolve the path to the video. path = pathlib.Path(path).expanduser().resolve() if path.is_dir(): raise IsADirectoryError(f"The path '{path}' is a directory.") if not exist_ok and path.is_file(): raise FileExistsError(f"The file '{path}' already exists.") media.write_video(path=path, images=np.array(self.frames), fps=self.fps) @staticmethod def compute_down_sampling(original_fps: int, target_min_fps: int) -> int: """ Return the integer down-sampling factor to reach at least the target fps. Args: original_fps: The original fps. target_min_fps: The target minimum fps. Returns: The down-sampling factor. """ down_sampling = 1 down_sampling_final = down_sampling while original_fps / (down_sampling + 1) >= target_min_fps: down_sampling = down_sampling + 1 if int(original_fps / down_sampling) == original_fps / down_sampling: down_sampling_final = down_sampling return down_sampling_final class MujocoVisualizer: """ Visualizer for the MuJoCo passive viewer. """ def __init__( self, model: mj.MjModel | None = None, data: mj.MjData | None = None ) -> None: """ Initialize the Mujoco visualizer. Args: model: The Mujoco model. data: The Mujoco data. 
""" self.data = data self.model = model def sync( self, viewer: mj.viewer.Handle, model: mj.MjModel | None = None, data: mj.MjData | None = None, ) -> None: """Update the viewer with the current model and data.""" data = data if data is not None else self.data model = model if model is not None else self.model mj.mj_forward(model, data) viewer.sync() def open_viewer( self, model: mj.MjModel | None = None, data: mj.MjData | None = None, show_left_ui: bool = False, ) -> mj.viewer.Handle: """Open a viewer.""" data = data if data is not None else self.data model = model if model is not None else self.model handle = mj.viewer.launch_passive( model, data, show_left_ui=show_left_ui, show_right_ui=False ) return handle @contextlib.contextmanager def open( self, model: mj.MjModel | None = None, data: mj.MjData | None = None, *, show_left_ui: bool = False, close_on_exit: bool = True, lookat: Sequence[float | int] | npt.NDArray | None = None, distance: float | int | npt.NDArray | None = None, azimuth: float | int | npt.NDArray | None = None, elevation: float | int | npt.NDArray | None = None, ) -> Iterator[mj.viewer.Handle]: """ Context manager to open the Mujoco passive viewer. Note: Refer to the Mujoco documentation for details of the camera options: https://mujoco.readthedocs.io/en/stable/XMLreference.html#visual-global """ handle = self.open_viewer(model=model, data=data, show_left_ui=show_left_ui) handle = MujocoVisualizer.setup_viewer_camera( viewer=handle, lookat=lookat, distance=distance, azimuth=azimuth, elevation=elevation, ) try: yield handle finally: _ = handle.close() if close_on_exit else None @staticmethod def setup_viewer_camera( viewer: mj.viewer.Handle, *, lookat: Sequence[float | int] | npt.NDArray | None, distance: float | int | npt.NDArray | None = None, azimuth: float | int | npt.NDArray | None = None, elevation: float | int | npt.NDArray | None = None, ) -> mj.viewer.Handle: """ Configure the initial viewpoint of the Mujoco passive viewer. 
Note: Refer to the Mujoco documentation for details of the camera options: https://mujoco.readthedocs.io/en/stable/XMLreference.html#visual-global Returns: The viewer with configured camera. """ if lookat is not None: lookat_array = np.array(lookat, dtype=float).squeeze() if lookat_array.size != 3: raise ValueError(lookat) viewer.cam.lookat = lookat_array if distance is not None: viewer.cam.distance = float(distance) if azimuth is not None: viewer.cam.azimuth = float(azimuth) % 360 if elevation is not None: viewer.cam.elevation = float(elevation) return viewer ================================================ FILE: src/jaxsim/parsers/__init__.py ================================================ ================================================ FILE: src/jaxsim/parsers/descriptions/__init__.py ================================================ from .collision import ( BoxCollision, CollidablePoint, CollisionShape, MeshCollision, SphereCollision, ) from .joint import JointDescription, JointGenericAxis, JointType from .link import LinkDescription from .model import ModelDescription ================================================ FILE: src/jaxsim/parsers/descriptions/collision.py ================================================ from __future__ import annotations import abc import dataclasses import jax.numpy as jnp import numpy as np import numpy.typing as npt import jaxsim.typing as jtp from jaxsim import logging from .link import LinkDescription @dataclasses.dataclass class CollidablePoint: """ Represents a collidable point associated with a parent link. Attributes: parent_link: The parent link to which the collidable point is attached. position: The position of the collidable point relative to the parent link. enabled: A flag indicating whether the collidable point is enabled for collision detection. 
""" parent_link: LinkDescription position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3)) enabled: bool = True def change_link( self, new_link: LinkDescription, new_H_old: npt.NDArray ) -> CollidablePoint: """ Move the collidable point to a new parent link. Args: new_link (LinkDescription): The new parent link to which the collidable point is moved. new_H_old (npt.NDArray): The transformation matrix from the new link's frame to the old link's frame. Returns: CollidablePoint: A new collidable point associated with the new parent link. """ msg = f"Moving collidable point: {self.parent_link.name} -> {new_link.name}" logging.debug(msg=msg) return CollidablePoint( parent_link=new_link, position=(new_H_old @ jnp.hstack([self.position, 1.0])).squeeze()[0:3], enabled=self.enabled, ) def __hash__(self) -> int: return hash( ( hash(self.parent_link), hash(tuple(self.position.tolist())), hash(self.enabled), ) ) def __eq__(self, other: CollidablePoint) -> bool: if not isinstance(other, CollidablePoint): return False return hash(self) == hash(other) def __str__(self) -> str: return ( f"{self.__class__.__name__}(" + f"parent_link={self.parent_link.name}" + f", position={self.position}" + f", enabled={self.enabled}" + ")" ) @dataclasses.dataclass class CollisionShape(abc.ABC): """ Abstract base class for representing collision shapes. Attributes: collidable_points: A list of collidable points associated with the collision shape. """ collidable_points: tuple[CollidablePoint] def __str__(self): return ( f"{self.__class__.__name__}(" + "collidable_points=[\n " + ",\n ".join(str(cp) for cp in self.collidable_points) + "\n])" ) @dataclasses.dataclass class BoxCollision(CollisionShape): """ Represents a box-shaped collision shape. Attributes: center: The center of the box in the local frame of the collision shape. 
""" center: jtp.VectorLike def __hash__(self) -> int: return hash( ( hash(super()), hash(tuple(self.center.tolist())), ) ) def __eq__(self, other: BoxCollision) -> bool: if not isinstance(other, BoxCollision): return False return hash(self) == hash(other) @dataclasses.dataclass class SphereCollision(CollisionShape): """ Represents a spherical collision shape. Attributes: center: The center of the sphere in the local frame of the collision shape. """ center: jtp.VectorLike def __hash__(self) -> int: return hash( ( hash(super()), hash(tuple(self.center.tolist())), ) ) def __eq__(self, other: BoxCollision) -> bool: if not isinstance(other, BoxCollision): return False return hash(self) == hash(other) @dataclasses.dataclass class MeshCollision(CollisionShape): """ Represents a mesh-shaped collision shape. Attributes: center: The center of the mesh in the local frame of the collision shape. """ center: jtp.VectorLike def __hash__(self) -> int: return hash( ( hash(tuple(self.center.tolist())), hash(self.collidable_points), ) ) def __eq__(self, other: MeshCollision) -> bool: if not isinstance(other, MeshCollision): return False return hash(self) == hash(other) ================================================ FILE: src/jaxsim/parsers/descriptions/joint.py ================================================ from __future__ import annotations import dataclasses from typing import ClassVar import jax_dataclasses import numpy as np import jaxsim.typing as jtp from jaxsim.utils import JaxsimDataclass, Mutability from .link import LinkDescription @dataclasses.dataclass(frozen=True) class JointType: """ Enumeration of joint types. """ Fixed: ClassVar[int] = 0 Revolute: ClassVar[int] = 1 Prismatic: ClassVar[int] = 2 @jax_dataclasses.pytree_dataclass class JointGenericAxis: """ A joint requiring the specification of a 3D axis. """ # The axis of rotation or translation of the joint (must have norm 1). 
axis: jtp.Vector def __hash__(self) -> int: return hash(tuple(self.axis.tolist())) def __eq__(self, other: JointGenericAxis) -> bool: if not isinstance(other, JointGenericAxis): return False return hash(self) == hash(other) @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) class JointDescription(JaxsimDataclass): """ In-memory description of a robot link. Attributes: name: The name of the joint. axis: The axis of rotation or translation for the joint. pose: The pose transformation matrix of the joint. jtype: The type of the joint. child: The child link attached to the joint. parent: The parent link attached to the joint. index: An optional index for the joint. friction_static: The static friction coefficient for the joint. friction_viscous: The viscous friction coefficient for the joint. position_limit_damper: The damper coefficient for position limits. position_limit_spring: The spring coefficient for position limits. position_limit: The position limits for the joint. initial_position: The initial position of the joint. 
""" name: jax_dataclasses.Static[str] axis: jtp.Vector pose: jtp.Matrix jtype: jax_dataclasses.Static[jtp.IntLike] child: LinkDescription = dataclasses.dataclass(repr=False) parent: LinkDescription = dataclasses.dataclass(repr=False) index: jtp.IntLike | None = None friction_static: jtp.FloatLike = 0.0 friction_viscous: jtp.FloatLike = 0.0 position_limit_damper: jtp.FloatLike = 0.0 position_limit_spring: jtp.FloatLike = 0.0 position_limit: tuple[jtp.FloatLike, jtp.FloatLike] = (0.0, 0.0) initial_position: jtp.FloatLike | jtp.VectorLike = 0.0 motor_inertia: jtp.FloatLike = 0.0 motor_viscous_friction: jtp.FloatLike = 0.0 motor_gear_ratio: jtp.FloatLike = 1.0 def __post_init__(self) -> None: if self.axis is not None: with self.mutable_context( mutability=Mutability.MUTABLE, restore_after_exception=False ): norm_of_axis = np.linalg.norm(self.axis) self.axis = self.axis / norm_of_axis def __eq__(self, other: JointDescription) -> bool: if not isinstance(other, JointDescription): return False return hash(self) == hash(other) def __hash__(self) -> int: from jaxsim.utils.wrappers import HashedNumpyArray return hash( ( hash(self.name), HashedNumpyArray.hash_of_array(self.axis), HashedNumpyArray.hash_of_array(self.pose), hash(int(self.jtype)), hash(self.child), hash(self.parent), hash(int(self.index)) if self.index is not None else 0, HashedNumpyArray.hash_of_array(self.friction_static), HashedNumpyArray.hash_of_array(self.friction_viscous), HashedNumpyArray.hash_of_array(self.position_limit_damper), HashedNumpyArray.hash_of_array(self.position_limit_spring), HashedNumpyArray.hash_of_array(self.position_limit), HashedNumpyArray.hash_of_array(self.initial_position), HashedNumpyArray.hash_of_array(self.motor_inertia), HashedNumpyArray.hash_of_array(self.motor_viscous_friction), HashedNumpyArray.hash_of_array(self.motor_gear_ratio), ), ) ================================================ FILE: src/jaxsim/parsers/descriptions/link.py ================================================ 
from __future__ import annotations import dataclasses import jax.numpy as jnp import jax_dataclasses import numpy as np from jax_dataclasses import Static import jaxsim.typing as jtp from jaxsim.math import Adjoint from jaxsim.utils import JaxsimDataclass @jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) class LinkDescription(JaxsimDataclass): """ In-memory description of a robot link. Attributes: name: The name of the link. mass: The mass of the link. inertia: The inertia tensor of the link. index: An optional index for the link (it gets automatically assigned). parent: The parent link of this link. pose: The pose transformation matrix of the link. children: The children links. """ name: Static[str] mass: float = dataclasses.field(repr=False) inertia: jtp.Matrix = dataclasses.field(repr=False) index: int | None = None parent_name: Static[str | None] = dataclasses.field(default=None, repr=False) pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False) children: Static[tuple[LinkDescription]] = dataclasses.field( default_factory=list, repr=False ) def __hash__(self) -> int: from jaxsim.utils.wrappers import HashedNumpyArray return hash( ( hash(self.name), hash(float(self.mass)), HashedNumpyArray.hash_of_array(self.inertia), hash(int(self.index)) if self.index is not None else 0, HashedNumpyArray.hash_of_array(self.pose), hash(tuple(self.children)), hash(self.parent_name) if self.parent_name is not None else 0, ) ) def __eq__(self, other: LinkDescription) -> bool: if not isinstance(other, LinkDescription): return False if not ( self.name == other.name and np.allclose(self.mass, other.mass) and np.allclose(self.inertia, other.inertia) and self.index == other.index and np.allclose(self.pose, other.pose) and self.children == other.children and self.parent_name == other.parent_name ): return False return True @property def name_and_index(self) -> str: """ Get a formatted string with the link's name and index. 
Returns: str: The formatted string. """ return f"#{self.index}_<{self.name}>" def lump_with( self, link: LinkDescription, lumped_H_removed: jtp.Matrix ) -> LinkDescription: """ Combine the current link with another link, preserving mass and inertia. Args: link: The link to combine with. lumped_H_removed: The transformation matrix between the two links. Returns: The combined link. """ # Get the 6D inertia of the link to remove. I_removed = link.inertia # Create the SE3 object. Note the inverse. r_X_l = Adjoint.from_transform(transform=lumped_H_removed, inverse=True) # Move the inertia I_removed_in_lumped_frame = r_X_l.transpose() @ I_removed @ r_X_l # Create the new combined link lumped_link = self.replace( mass=self.mass + link.mass, inertia=self.inertia + I_removed_in_lumped_frame, ) return lumped_link ================================================ FILE: src/jaxsim/parsers/descriptions/model.py ================================================ from __future__ import annotations import dataclasses import itertools from collections.abc import Sequence from jaxsim import logging from jaxsim.logging import jaxsim_warn from ..kinematic_graph import KinematicGraph, KinematicGraphTransforms, RootPose from .collision import CollidablePoint, CollisionShape from .joint import JointDescription from .link import LinkDescription @dataclasses.dataclass(frozen=True, eq=False, unsafe_hash=False) class ModelDescription(KinematicGraph): """ Intermediate representation representing the kinematic graph of a robot model. Attributes: name: The name of the model. fixed_base: Whether the model is either fixed-base or floating-base. collision_shapes: List of collision shapes associated with the model. """ name: str = None fixed_base: bool = True collision_shapes: tuple[CollisionShape, ...] 
= dataclasses.field( default_factory=list, repr=False ) @staticmethod def build_model_from( name: str, links: list[LinkDescription], joints: list[JointDescription], frames: list[LinkDescription] | None = None, collisions: tuple[CollisionShape, ...] = (), fixed_base: bool = False, base_link_name: str | None = None, considered_joints: Sequence[str] | None = None, model_pose: RootPose = RootPose(), ) -> ModelDescription: """ Build a model description from provided components. Args: name: The name of the model. links: List of link descriptions. joints: List of joint descriptions. frames: List of frame descriptions. collisions: List of collision shapes associated with the model. fixed_base: Indicates whether the model has a fixed base. base_link_name: Name of the base link (i.e. the root of the kinematic tree). considered_joints: List of joint names to consider (by default all joints). model_pose: Pose of the model's root (by default an identity transform). Returns: A ModelDescription instance representing the model. """ # Create the full kinematic graph. kinematic_graph = KinematicGraph.build_from( links=links, joints=joints, frames=frames, root_link_name=base_link_name, root_pose=model_pose, ) # Reduce the graph if needed. if considered_joints is not None: kinematic_graph = kinematic_graph.reduce( considered_joints=considered_joints ) # Create the object to compute forward kinematics. fk = KinematicGraphTransforms(graph=kinematic_graph) # Container of the final model's collision shapes. final_collisions: list[CollisionShape] = [] # Move and express the collision shapes of removed links to the resulting # lumped link that replace the combination of the removed link and its parent. 
for collision_shape in collisions: # Assume they have an unique parent link if ( len({cp.parent_link.name for cp in collision_shape.collidable_points}) != 1 ): msg = "Collision shape not currently supported (multiple parent links)" raise RuntimeError(msg) # Get the parent link of the collision shape. # Note that this link could have been lumped and we need to find the # link in which it was lumped into. parent_link_of_shape = collision_shape.collidable_points[0].parent_link # If it is part of the (reduced) graph, add it as it is... if parent_link_of_shape.name in kinematic_graph.link_names(): final_collisions.append(collision_shape) continue # ... otherwise look for the frame if parent_link_of_shape.name not in kinematic_graph.frame_names(): msg = "Parent frame '{}' of collision shape not found, ignoring shape" logging.info(msg.format(parent_link_of_shape.name)) continue # Create a new collision shape new_collision_shape = CollisionShape(collidable_points=()) final_collisions.append(new_collision_shape) # If the frame was found, update the collidable points' pose and add them # to the new collision shape. for cp in collision_shape.collidable_points: # Find the link that is part of the (reduced) model in which the # collision shape's parent was lumped into real_parent_link_name = kinematic_graph.frames_dict[ parent_link_of_shape.name ].parent_name # Change the link associated to the collidable point, updating their # relative pose moved_cp = cp.change_link( new_link=kinematic_graph.links_dict[real_parent_link_name], new_H_old=fk.relative_transform( relative_to=real_parent_link_name, name=cp.parent_link.name, ), ) # Store the updated collision. 
new_collision_shape.collidable_points += (moved_cp,) # Build the model model = ModelDescription( name=name, root_pose=kinematic_graph.root_pose, fixed_base=fixed_base, collision_shapes=tuple(final_collisions), root=kinematic_graph.root, joints=kinematic_graph.joints, frames=kinematic_graph.frames, _joints_removed=kinematic_graph.joints_removed, ) # Check that the root link of kinematic graph is the desired base link. assert kinematic_graph.root.name == base_link_name, kinematic_graph.root.name return model def reduce(self, considered_joints: Sequence[str]) -> ModelDescription: """ Reduce the model by removing specified joints. Args: considered_joints: Sequence of joint names to consider. Returns: A `ModelDescription` instance that only includes the considered joints. """ jaxsim_warn( "The joint order in the model description is not preserved when reducing " "the model. Consider using the `names_to_indices` method to get the correct " "order of the joints, or use the `joint_names()` method to inspect the internal joint ordering." ) if set(considered_joints) - set(self.joint_names()): extra_joints = set(considered_joints) - set(self.joint_names()) msg = f"Found joints not part of the model: {extra_joints}" raise ValueError(msg) reduced_model_description = ModelDescription.build_model_from( name=self.name, links=list(self.links_dict.values()), joints=self.joints, frames=self.frames, collisions=self.collision_shapes, fixed_base=self.fixed_base, base_link_name=next(iter(self)).name, model_pose=self.root_pose, considered_joints=considered_joints, ) # Include the unconnected/removed joints from the original model. for joint in self.joints_removed: reduced_model_description.joints_removed.append(joint) return reduced_model_description def update_collision_shape_of_link(self, link_name: str, enabled: bool) -> None: """ Enable or disable collision shapes associated with a link. Args: link_name: The name of the link. 
enabled: Enable or disable collision shapes associated with the link. """ if link_name not in self.link_names(): raise ValueError(link_name) for point in self.collision_shape_of_link( link_name=link_name ).collidable_points: point.enabled = enabled def collision_shape_of_link(self, link_name: str) -> CollisionShape: """ Get the collision shape associated with a specific link. Args: link_name: The name of the link. Returns: The collision shape associated with the link. """ if link_name not in self.link_names(): raise ValueError(link_name) return CollisionShape( collidable_points=[ point for shape in self.collision_shapes for point in shape.collidable_points if point.parent_link.name == link_name ] ) def all_enabled_collidable_points(self) -> list[CollidablePoint]: """ Get all enabled collidable points in the model. Returns: The list of all enabled collidable points. """ # Get iterator of all collidable points all_collidable_points = itertools.chain.from_iterable( [shape.collidable_points for shape in self.collision_shapes] ) # Return enabled collidable points return [cp for cp in all_collidable_points if cp.enabled] def __eq__(self, other: ModelDescription) -> bool: if not isinstance(other, ModelDescription): return False if not ( self.name == other.name and self.fixed_base == other.fixed_base and self.root == other.root and self.joints == other.joints and self.frames == other.frames and self.root_pose == other.root_pose ): return False return True def __hash__(self) -> int: return hash( ( hash(self.name), hash(self.fixed_base), hash(self.root), hash(tuple(self.joints)), hash(tuple(self.frames)), hash(self.root_pose), ) ) ================================================ FILE: src/jaxsim/parsers/kinematic_graph.py ================================================ from __future__ import annotations import copy import dataclasses import functools from collections.abc import Callable, Iterable, Iterator, Sequence from typing import Any import numpy as np import 
numpy.typing as npt import jaxsim.utils from jaxsim import logging from jaxsim.utils import Mutability from .descriptions.joint import JointDescription, JointType from .descriptions.link import LinkDescription @dataclasses.dataclass class RootPose: """ Represents the root pose in a kinematic graph. Attributes: root_position: The 3D position of the root link of the graph. root_quaternion: The quaternion representing the rotation of the root link of the graph. Note: The root link of the kinematic graph is the base link. """ root_position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3)) root_quaternion: npt.NDArray = dataclasses.field( default_factory=lambda: np.array([1.0, 0, 0, 0]) ) def __hash__(self) -> int: from jaxsim.utils.wrappers import HashedNumpyArray return hash( ( HashedNumpyArray.hash_of_array(self.root_position), HashedNumpyArray.hash_of_array(self.root_quaternion), ) ) def __eq__(self, other: RootPose) -> bool: if not isinstance(other, RootPose): return False if not np.allclose(self.root_position, other.root_position): return False if not np.allclose(self.root_quaternion, other.root_quaternion): return False return True @dataclasses.dataclass(frozen=True) class KinematicGraph(Sequence[LinkDescription]): """ Class storing a kinematic graph having links as nodes and joints as edges. Attributes: root: The root node of the kinematic graph. frames: List of frames rigidly attached to the graph nodes. joints: List of joints connecting the graph nodes. root_pose: The pose of the kinematic graph's root. """ root: LinkDescription frames: list[LinkDescription] = dataclasses.field( default_factory=list, hash=False, compare=False ) joints: list[JointDescription] = dataclasses.field( default_factory=list, hash=False, compare=False ) root_pose: RootPose = dataclasses.field(default_factory=RootPose) # Private attribute storing optional additional info. 
_extra_info: dict[str, Any] = dataclasses.field( default_factory=dict, repr=False, hash=False, compare=False ) # Private attribute storing the unconnected joints from the parsed model and # the joints removed after model reduction. _joints_removed: list[JointDescription] = dataclasses.field( default_factory=list, repr=False, hash=False, compare=False ) @functools.cached_property def links_dict(self) -> dict[str, LinkDescription]: """ Get a dictionary of links indexed by their name. """ return {l.name: l for l in iter(self)} @functools.cached_property def frames_dict(self) -> dict[str, LinkDescription]: """ Get a dictionary of frames indexed by their name. """ return {f.name: f for f in self.frames} @functools.cached_property def joints_dict(self) -> dict[str, JointDescription]: """ Get a dictionary of joints indexed by their name. """ return {j.name: j for j in self.joints} @functools.cached_property def joints_connection_dict( self, ) -> dict[tuple[str, str], JointDescription]: """ Get a dictionary of joints indexed by the tuple (parent, child) link names. """ return {(j.parent.name, j.child.name): j for j in self.joints} def __post_init__(self) -> None: # Assign the link index by traversing the graph with BFS. # Here we assume the model being fixed-base, therefore the base link will # have index 0. We will deal with the floating base in a later stage. for index, link in enumerate(self): link.mutable(validate=False).index = index # Get the names of the links, frames, and joints. link_names = [l.name for l in self] frame_names = [f.name for f in self.frames] joint_names = [j.name for j in self.joints] # Make sure that they are unique. assert len(link_names) == len(set(link_names)) assert len(frame_names) == len(set(frame_names)) assert len(joint_names) == len(set(joint_names)) assert set(link_names).isdisjoint(set(frame_names)) assert set(link_names).isdisjoint(set(joint_names)) # Order frames with their name. 
super().__setattr__("frames", sorted(self.frames, key=lambda f: f.name)) # Assign the frame index following the name-based indexing. # We assume the model being fixed-base, therefore the first frame will # have last_link_idx + 1. for index, frame in enumerate(self.frames): with frame.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): frame.index = index + len(self.link_names()) # Number joints so that their index matches their child link index. # Therefore, the first joint has index 1. links_dict = {l.name: l for l in iter(self)} for joint in self.joints: with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): joint.index = links_dict[joint.child.name].index # Check that joint indices are unique. assert len([j.index for j in self.joints]) == len( {j.index for j in self.joints} ) # Order joints with their indices. super().__setattr__("joints", sorted(self.joints, key=lambda j: j.index)) @staticmethod def build_from( links: list[LinkDescription], joints: list[JointDescription], frames: list[LinkDescription] | None = None, root_link_name: str | None = None, root_pose: RootPose = RootPose(), ) -> KinematicGraph: """ Build a KinematicGraph from links, joints, and frames. Args: links: A list of link descriptions. joints: A list of joint descriptions. frames: A list of frame descriptions. root_link_name: The name of the root link. If not provided, it's assumed to be the first link's name. root_pose: The root pose of the kinematic graph. Returns: The resulting kinematic graph. """ # Consider the first link as the root link if not provided. if root_link_name is None: root_link_name = links[0].name logging.debug(msg=f"Assuming '{root_link_name}' as the root link") # Couple links and joints and create the graph of links. # Note that the pose of the frames is not updated; it is the caller's # responsibility to update their pose if they want to use them. 
( graph_root_node, graph_joints, graph_frames, unconnected_links, unconnected_joints, unconnected_frames, ) = KinematicGraph._create_graph( links=links, joints=joints, root_link_name=root_link_name, frames=frames ) for link in unconnected_links: logging.warning(msg=f"Ignoring unconnected link: '{link.name}'") for joint in unconnected_joints: logging.warning(msg=f"Ignoring unconnected joint: '{joint.name}'") for frame in unconnected_frames: logging.warning(msg=f"Ignoring unconnected frame: '{frame.name}'") return KinematicGraph( root=graph_root_node, joints=graph_joints, frames=graph_frames, root_pose=root_pose, _joints_removed=unconnected_joints, ) @staticmethod def _create_graph( links: list[LinkDescription], joints: list[JointDescription], root_link_name: str, frames: list[LinkDescription] | None = None, ) -> tuple[ LinkDescription, list[JointDescription], list[LinkDescription], list[LinkDescription], list[JointDescription], list[LinkDescription], ]: """ Low-level creator of kinematic graph components. Args: links: A list of parsed link descriptions. joints: A list of parsed joint descriptions. root_link_name: The name of the root link used as root node of the graph. frames: A list of parsed frame descriptions. Returns: A tuple containing the root node of the graph (defining the entire kinematic tree by iterating on its child nodes), the list of joints representing the actual graph edges, the list of frames rigidly attached to the graph nodes, the list of unconnected links, the list of unconnected joints, and the list of unconnected frames. """ # Create a dictionary that maps the link name to the link, for easy retrieval. links_dict: dict[str, LinkDescription] = { l.name: l.mutable(validate=False) for l in links } # Create an empty list of frames if not provided. frames = frames if frames is not None else [] # Create a dictionary that maps the frame name to the frame, for easy retrieval. 
frames_dict = {frame.name: frame for frame in frames} # Check that our parser correctly resolved the frame's parent to be a link. for frame in frames: assert frame.parent_name != "", frame assert frame.parent_name is not None, frame assert frame.parent_name != "__model__", frame assert frame.parent_name not in frames_dict, frame # =========================================================== # Populate the kinematic graph with links, joints, and frames # =========================================================== # Check the existence of the root link. if root_link_name not in links_dict: raise ValueError(root_link_name) # Reset the connections of the root link. for link in links_dict.values(): link.children = tuple() # Couple links and joints creating the kinematic graph. for joint in joints: # Get the parent and child links of the joint. parent_link = links_dict[joint.parent.name] child_link = links_dict[joint.child.name] assert child_link.name == joint.child.name assert parent_link.name == joint.parent.name # Assign link's parent. child_link.parent_name = parent_link.name # Assign link's children and make sure they are unique. if child_link.name not in {l.name for l in parent_link.children}: with parent_link.mutable_context(Mutability.MUTABLE_NO_VALIDATION): parent_link.children = (*parent_link.children, child_link) # Collect all the links of the kinematic graph. all_links_in_graph = list( KinematicGraph.breadth_first_search(root=links_dict[root_link_name]) ) # Get the names of all links in the kinematic graph. all_link_names_in_graph = [l.name for l in all_links_in_graph] # Collect all the joints of the kinematic graph. all_joints_in_graph = [ joint for joint in joints if joint.parent.name in all_link_names_in_graph and joint.child.name in all_link_names_in_graph ] # Get the names of all joints in the kinematic graph. all_joint_names_in_graph = [j.name for j in all_joints_in_graph] # Collect all the frames of the kinematic graph. 
# Note: our parser ensures that the parent of a frame is not another frame. all_frames_in_graph = [ frame for frame in frames if frame.parent_name in all_link_names_in_graph ] # Get the names of all frames in the kinematic graph. all_frames_names_in_graph = [f.name for f in all_frames_in_graph] # ============================ # Collect unconnected elements # ============================ # Collect all the joints that are not part of the kinematic graph. removed_joints = [j for j in joints if j.name not in all_joint_names_in_graph] for joint in removed_joints: msg = "Joint '{}' is unconnected and it will be removed" logging.debug(msg=msg.format(joint.name)) # Collect all the links that are not part of the kinematic graph. unconnected_links = [l for l in links if l.name not in all_link_names_in_graph] # Update the unconnected links by removing their children. The other properties # are left untouched, it's caller responsibility to post-process them if needed. for link in unconnected_links: link.children = tuple() msg = "Link '{}' won't be part of the kinematic graph because unconnected" logging.debug(msg=msg.format(link.name)) # Collect all the frames that are not part of the kinematic graph. unconnected_frames = [ f for f in frames if f.name not in all_frames_names_in_graph ] for frame in unconnected_frames: msg = "Frame '{}' won't be part of the kinematic graph because unconnected" logging.debug(msg=msg.format(frame.name)) return ( links_dict[root_link_name].mutable(mutable=False), list(set(joints) - set(removed_joints)), all_frames_in_graph, unconnected_links, list(set(removed_joints)), unconnected_frames, ) def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph: """ Reduce the kinematic graph by removing unspecified joints. When a joint is removed, the mass and inertia of its child link are lumped with those of its parent link, obtaining a new link that combines the two. 
The description of the removed joint specifies the default angle (usually 0) that is considered when the joint is removed. Args: considered_joints: A list of joint names to consider. Returns: The reduced kinematic graph. """ # The current object represents the complete kinematic graph full_graph = self # Get the names of the joints to remove joint_names_to_remove = list( set(full_graph.joint_names()) - set(considered_joints) ) # Return early if there is no action to take if len(joint_names_to_remove) == 0: logging.info("The kinematic graph doesn't need to be reduced") return copy.deepcopy(self) # Check if all considered joints are part of the full kinematic graph if set(considered_joints) - {j.name for j in full_graph.joints}: extra_j = set(considered_joints) - {j.name for j in full_graph.joints} msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})" raise ValueError(msg) # Extract data we need to modify from the full graph links_dict = copy.deepcopy(full_graph.links_dict) joints_dict = copy.deepcopy(full_graph.joints_dict) # Create the object to compute forward kinematics. fk = KinematicGraphTransforms(graph=full_graph) # The following steps are implemented below in order to create the reduced graph: # # 1. Lump the mass of the removed links into their parent # 2. Update the pose and parent link of joints having the removed link as parent # 3. Create the reduced graph considering the removed links as frames # 4. Resolve the pose of the frames wrt their reduced graph parent # # We name "removed link" the link to remove, and "lumped link" the new link that # combines the removed link and its parent. The lumped link will share the frame # of the removed link's parent and the inertial properties of the two links that # have been combined. # ======================================================= # 1. Lump the mass of the removed links into their parent # ======================================================= # Get all the links to remove. 
They will be lumped with their parent. links_to_remove = [ joint.child.name for joint_name, joint in joints_dict.items() if joint_name in joint_names_to_remove ] # Lump the mass and the inertia traversing the tree from the leaf to the root, # this way we propagate these properties back even in the case when also the # parent link of a removed joint has to be lumped with its parent. for link in reversed(full_graph): if link.name not in links_to_remove: continue # Get the link to remove and its parent, i.e. the lumped link link_to_remove = links_dict[link.name] parent_of_link_to_remove = links_dict[link.parent_name] msg = "Lumping chain: {}->({})->{}" logging.debug( msg.format( link_to_remove.name, self.joints_connection_dict[ parent_of_link_to_remove.name, link_to_remove.name ].name, parent_of_link_to_remove.name, ) ) # Lump the link lumped_link = parent_of_link_to_remove.lump_with( link=link_to_remove, lumped_H_removed=fk.relative_transform( relative_to=parent_of_link_to_remove.name, name=link_to_remove.name ), ) # Pop the original two links from the dictionary... _ = links_dict.pop(link_to_remove.name) _ = links_dict.pop(parent_of_link_to_remove.name) # ... and insert the lumped link (having the same name of the parent) links_dict[lumped_link.name] = lumped_link # Insert back in the dict an entry from the removed link name to the new # lumped link. We need this info later, when we process the remaining joints. links_dict[link_to_remove.name] = lumped_link # As a consequence of the back-insertion, we need to adjust the resulting # lumped link of links that have been removed previously. # Note: in the dictionary, only items whose key is not matching value.name # are links that have been removed. 
for previously_removed_link_name in { link_name for link_name, link in links_dict.items() if link_name != link.name and link.name == link_to_remove.name }: links_dict[previously_removed_link_name] = lumped_link # ============================================================================== # 2. Update the pose and parent link of joints having the removed link as parent # ============================================================================== # Find the joints having the removed links as parent joints_with_removed_parent_link = [ joints_dict[joint_name] for joint_name in considered_joints if joints_dict[joint_name].parent.name in links_to_remove ] # Update the pose of all joints having as parent link a removed link for joint in joints_with_removed_parent_link: # Update the pose. Note that after the lumping process, the dict entry # links_dict[joint.parent.name] contains the final lumped link with joint.mutable_context(mutability=Mutability.MUTABLE): joint.pose = fk.relative_transform( relative_to=links_dict[joint.parent.name].name, name=joint.name ) with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): # Update the parent link joint.parent = links_dict[joint.parent.name] # =================================================================== # 3. Create the reduced graph considering the removed links as frames # =================================================================== # Get all the original links from the full graph full_graph_links_dict = copy.deepcopy(full_graph.links_dict) # Get all the final links from the reduced graph links_to_keep = [ l for link_name, l in links_dict.items() if link_name not in links_to_remove ] # Override the entries of the full graph with those of the reduced graph. # Those that are not overridden will become frames. for link in links_to_keep: full_graph_links_dict[link.name] = link # Create the reduced graph data. 
We pass the full list of links so that those # that are not part of the graph will be returned as frames. ( reduced_root_node, reduced_joints, reduced_frames, unconnected_links, unconnected_joints, unconnected_frames, ) = KinematicGraph._create_graph( links=list(full_graph_links_dict.values()), joints=[joints_dict[joint_name] for joint_name in considered_joints], root_link_name=full_graph.root.name, ) assert {f.name for f in self.frames}.isdisjoint( {f.name for f in unconnected_frames + reduced_frames} ) for link in unconnected_links: logging.debug(msg=f"Link '{link.name}' is unconnected and became a frame") # Create the reduced graph. reduced_graph = KinematicGraph( root=reduced_root_node, joints=reduced_joints, frames=self.frames + unconnected_links + reduced_frames, root_pose=full_graph.root_pose, _joints_removed=( self._joints_removed + unconnected_joints + [joints_dict[name] for name in joint_names_to_remove] ), ) # ================================================================ # 4. Resolve the pose of the frames wrt their reduced graph parent # ================================================================ # Build a new object to compute FK on the reduced graph. fk_reduced = KinematicGraphTransforms(graph=reduced_graph) # We need to adjust the pose of the frames since their parent link # could have been removed by the reduction process. for frame in reduced_graph.frames: # Always find the real parent link of the frame name_of_new_parent_link = fk_reduced.find_parent_link_of_frame( name=frame.name ) assert name_of_new_parent_link in reduced_graph, name_of_new_parent_link # Notify the user if the parent link has changed. if name_of_new_parent_link != frame.parent_name: msg = "New parent of frame '{}' is '{}'" logging.debug(msg=msg.format(frame.name, name_of_new_parent_link)) # Always recompute the pose of the frame, and set zero inertial params. 
with frame.mutable_context(jaxsim.utils.Mutability.MUTABLE_NO_VALIDATION): # Update kinematic parameters of the frame. # Note that here we compute the transform using the FK object of the # full model, so that we are sure that the kinematic is not altered. frame.pose = fk.relative_transform( relative_to=name_of_new_parent_link, name=frame.name ) # Update the parent link such that the pose is expressed in its frame. frame.parent_name = name_of_new_parent_link # Update dynamic parameters of the frame. frame.mass = 0.0 frame.inertia = np.zeros_like(frame.inertia) # Return the reduced graph. return reduced_graph def link_names(self) -> list[str]: """ Get the names of all links in the kinematic graph (i.e. the nodes). Returns: The list of link names. """ return list(self.links_dict.keys()) def joint_names(self) -> list[str]: """ Get the names of all joints in the kinematic graph (i.e. the edges). Returns: The list of joint names. """ return list(self.joints_dict.keys()) def frame_names(self) -> list[str]: """ Get the names of all frames in the kinematic graph. Returns: The list of frame names. """ return list(self.frames_dict.keys()) def print_tree(self) -> None: """ Print the tree structure of the kinematic graph. """ import pptree root_node = self.root pptree.print_tree( root_node, childattr="children", nameattr="name_and_index", horizontal=True, ) @property def joints_removed(self) -> list[JointDescription]: """ Get the list of joints removed during the graph reduction. Returns: The list of removed joints. """ return self._joints_removed @staticmethod def breadth_first_search( root: LinkDescription, sort_children: Callable[[Any], Any] | None = lambda link: link.name, ) -> Iterable[LinkDescription]: """ Perform a breadth-first search (BFS) traversal of the kinematic graph. Args: root: The root link for BFS. sort_children: A function to sort children of a node. Yields: The links in the kinematic graph in BFS order. """ # Initialize the queue with the root node. 
queue = [root] # We assume that nodes have unique names and mark a link as visited using # its name. This speeds up considerably object comparison. visited = [] visited.append(root.name) yield root while queue: # Extract the first element of the queue. l = queue.pop(0) # Note: sorting the links with their name so that the order of children # insertion does not matter when assigning the link index. for child in sorted(l.children, key=sort_children): if child.name in visited: continue visited.append(child.name) queue.append(child) yield child # ================= # Sequence protocol # ================= def __iter__(self) -> Iterator[LinkDescription]: yield from KinematicGraph.breadth_first_search(root=self.root) def __reversed__(self) -> Iterable[LinkDescription]: yield from reversed(list(iter(self))) def __len__(self) -> int: return len(list(iter(self))) def __contains__(self, item: str | LinkDescription) -> bool: if isinstance(item, str): return item in self.link_names() if isinstance(item, LinkDescription): return item in set(iter(self)) raise TypeError(type(item).__name__) def __getitem__(self, key: int | str) -> LinkDescription: if isinstance(key, str): if key not in self.link_names(): raise KeyError(key) return self.links_dict[key] if isinstance(key, int): if key > len(self): raise KeyError(key) return list(iter(self))[key] raise TypeError(type(key).__name__) def count(self, value: LinkDescription) -> int: """ Count the occurrences of a link in the kinematic graph. """ return list(iter(self)).count(value) def index(self, value: LinkDescription, start: int = 0, stop: int = -1) -> int: """ Find the index of a link in the kinematic graph. """ return list(iter(self)).index(value, start, stop) # ==================== # Other useful classes # ==================== @dataclasses.dataclass(frozen=True) class KinematicGraphTransforms: """ Class to compute forward kinematics on a kinematic graph. Attributes: graph: The kinematic graph on which to compute forward kinematics. 
""" graph: KinematicGraph _transform_cache: dict[str, npt.NDArray] = dataclasses.field( default_factory=dict, init=False, repr=False, compare=False ) _initial_joint_positions: dict[str, float] = dataclasses.field( init=False, repr=False, compare=False ) def __post_init__(self) -> None: super().__setattr__( "_initial_joint_positions", {joint.name: joint.initial_position for joint in self.graph.joints}, ) @property def initial_joint_positions(self) -> npt.NDArray: """ Get the initial joint positions of the kinematic graph. """ return np.atleast_1d( np.array(list(self._initial_joint_positions.values())) ).astype(float) @initial_joint_positions.setter def initial_joint_positions( self, positions: npt.NDArray | Sequence, joint_names: Sequence[str] | None = None, ) -> None: joint_names = ( joint_names if joint_names is not None else list(self._initial_joint_positions.keys()) ) s = np.atleast_1d(np.array(positions).squeeze()) if s.size != len(joint_names): raise ValueError(s.size, len(joint_names)) for joint_name in joint_names: if joint_name not in self._initial_joint_positions: raise ValueError(joint_name) # Clear transform cache. self._transform_cache.clear() # Update initial joint positions. for joint_name, position in zip(joint_names, s, strict=True): self._initial_joint_positions[joint_name] = position def transform(self, name: str) -> npt.NDArray: """ Compute the SE(3) transform of elements belonging to the kinematic graph. Args: name: The name of a link, a joint, or a frame. Returns: The 4x4 transform matrix of the element w.r.t. the model frame. """ # If the transform was already computed, return it. if name in self._transform_cache: return self._transform_cache[name] # If the name is a joint, compute M_H_J transform. if name in self.graph.joint_names(): # Get the joint. joint = self.graph.joints_dict[name] assert joint.name == name # Get the transform of the parent link. 
M_H_L = self.transform(name=joint.parent.name) # Rename the pose of the predecessor joint frame w.r.t. its parent link. L_H_pre = joint.pose # Compute the joint transform from the predecessor to the successor frame. pre_H_J = self.pre_H_suc( joint_type=joint.jtype, joint_axis=joint.axis, joint_position=self._initial_joint_positions[joint.name], ) # Compute the M_H_J transform. self._transform_cache[name] = M_H_L @ L_H_pre @ pre_H_J return self._transform_cache[name] # If the name is a link, compute M_H_L transform. if name in self.graph.link_names(): # Get the link. link = self.graph.links_dict[name] # Handle the pose between the __model__ frame and the root link. if link.name == self.graph.root.name: M_H_B = link.pose return M_H_B # Get the joint between the link and its parent. parent_joint = self.graph.joints_connection_dict[ link.parent_name, link.name ] # Get the transform of the parent joint. M_H_J = self.transform(name=parent_joint.name) # Rename the pose of the link w.r.t. its parent joint. J_H_L = link.pose # Compute the M_H_L transform. self._transform_cache[name] = M_H_J @ J_H_L return self._transform_cache[name] # It can only be a plain frame. if name not in self.graph.frame_names(): raise ValueError(name) # Get the frame. frame = self.graph.frames_dict[name] # Get the transform of the parent link. M_H_L = self.transform(name=frame.parent_name) # Rename the pose of the frame w.r.t. its parent link. L_H_F = frame.pose # Compute the M_H_F transform. self._transform_cache[name] = M_H_L @ L_H_F return self._transform_cache[name] def relative_transform(self, relative_to: str, name: str) -> npt.NDArray: """ Compute the SE(3) relative transform of elements belonging to the kinematic graph. Args: relative_to: The name of the reference element. name: The name of a link, a joint, or a frame. Returns: The 4x4 transform matrix of the element w.r.t. the desired frame. 
""" import jaxsim.math M_H_target = self.transform(name=name) M_H_R = self.transform(name=relative_to) # Compute the relative transform R_H_target, where R is the reference frame, # and i the frame of the desired link|joint|frame. return np.array(jaxsim.math.Transform.inverse(M_H_R)) @ M_H_target @staticmethod def pre_H_suc( joint_type: JointType, joint_axis: npt.NDArray, joint_position: float | None = None, ) -> npt.NDArray: """ Compute the SE(3) transform from the predecessor to the successor frame. Args: joint_type: The type of the joint. joint_axis: The axis of the joint. joint_position: The position of the joint. Returns: The 4x4 transform matrix from the predecessor to the successor frame. """ import jaxsim.math return np.array( jaxsim.math.supported_joint_motion(joint_type, joint_position, joint_axis) ) def find_parent_link_of_frame(self, name: str) -> str: """ Find the parent link of a frame. Args: name: The name of the frame. Returns: The name of the parent link of the frame. """ try: frame = self.graph.frames_dict[name] except KeyError as e: raise ValueError(f"Frame '{name}' not found in the kinematic graph") from e if frame.parent_name in self.graph.links_dict: return frame.parent_name if frame.parent_name in self.graph.frames_dict: return self.find_parent_link_of_frame(name=frame.parent_name) msg = f"Failed to find parent element of frame '{name}' with name '{frame.parent_name}'" raise RuntimeError(msg) ================================================ FILE: src/jaxsim/parsers/rod/__init__.py ================================================ from . import parser, utils from .parser import build_model_description, extract_model_data ================================================ FILE: src/jaxsim/parsers/rod/meshes.py ================================================ import numpy as np import trimesh VALID_AXIS = {"x": 0, "y": 1, "z": 2} def extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray: """ Extract the vertices of a mesh as points. 
""" return mesh.vertices def extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n) -> np.ndarray: """ Extract N random points from the surface of a mesh. Args: mesh: The mesh from which to extract points. n: The number of points to extract. Returns: The extracted points (N x 3 array). """ return mesh.sample(n) def extract_points_uniform_surface_sampling( mesh: trimesh.Trimesh, n: int ) -> np.ndarray: """ Extract N uniformly sampled points from the surface of a mesh. Args: mesh: The mesh from which to extract points. n: The number of points to extract. Returns: The extracted points (N x 3 array). """ return trimesh.sample.sample_surface_even(mesh=mesh, count=n)[0] def extract_points_select_points_over_axis( mesh: trimesh.Trimesh, axis: str, direction: str, n: int ) -> np.ndarray: """ Extract N points from a mesh along a specified axis. The points are selected based on their position along the axis. Args: mesh: The mesh from which to extract points. axis: The axis along which to extract points. direction: The direction along the axis from which to extract points. Valid values are "higher" and "lower". n: The number of points to extract. Returns: The extracted points (N x 3 array). """ dirs = {"higher": np.s_[-n:], "lower": np.s_[:n]} arr = mesh.vertices # Sort rows lexicographically first, then columnar. arr.sort(axis=0) sorted_arr = arr[dirs[direction]] return sorted_arr def extract_points_aap( mesh: trimesh.Trimesh, axis: str, upper: float | None = None, lower: float | None = None, ) -> np.ndarray: """ Extract points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis. Args: mesh: The mesh from which to extract points. axis: The axis along which to extract points. upper: The upper bound of the range. lower: The lower bound of the range. Returns: The extracted points (N x 3 array). Raises: AssertionError: If the lower bound is greater than the upper bound. """ # Check bounds. 
upper = upper if upper is not None else np.inf lower = lower if lower is not None else -np.inf assert lower < upper, "Invalid bounds for axis-aligned plane" # Logic. points = mesh.vertices[ (mesh.vertices[:, VALID_AXIS[axis]] >= lower) & (mesh.vertices[:, VALID_AXIS[axis]] <= upper) ] return points ================================================ FILE: src/jaxsim/parsers/rod/parser.py ================================================ import dataclasses import os import pathlib from typing import NamedTuple import jax.numpy as jnp import numpy as np import rod from jaxsim import logging from jaxsim.math import Quaternion from jaxsim.parsers import descriptions, kinematic_graph from . import utils class SDFData(NamedTuple): """ Data extracted from an SDF resource useful to build a JaxSim model. """ model_name: str fixed_base: bool base_link_name: str link_descriptions: list[descriptions.LinkDescription] joint_descriptions: list[descriptions.JointDescription] frame_descriptions: list[descriptions.LinkDescription] collision_shapes: list[descriptions.CollisionShape] sdf_model: rod.Model | None = None model_pose: kinematic_graph.RootPose = kinematic_graph.RootPose() def extract_model_data( model_description: pathlib.Path | str | rod.Model | rod.Sdf, model_name: str | None = None, is_urdf: bool | None = None, ) -> SDFData: """ Extract data from an SDF/URDF resource useful to build a JaxSim model. Args: model_description: A path to an SDF/URDF file, a string containing its content, or a pre-parsed/pre-built rod model. model_name: The name of the model to extract from the SDF resource. is_urdf: Whether to force parsing the resource as a URDF file. Automatically detected if not provided. Returns: The extracted model data. 
""" match model_description: case rod.Model(): sdf_model = model_description case rod.Sdf() | str() | pathlib.Path(): sdf_element = ( model_description if isinstance(model_description, rod.Sdf) else rod.Sdf.load(sdf=model_description, is_urdf=is_urdf) ) if not sdf_element.models(): raise RuntimeError("Failed to find any model in SDF resource") # Assume the SDF resource has only one model, or the desired model name is given. sdf_models = {m.name: m for m in sdf_element.models()} sdf_model = ( sdf_element.models()[0] if len(sdf_models) == 1 else sdf_models[model_name] ) # Log model name. logging.info(msg=f"Found model '{sdf_model.name}' in SDF resource") # Jaxsim supports only models compatible with URDF, i.e. those having all links # directly attached to their parent joint without additional roto-translations. # Furthermore, the following switch also post-processes frames such that their # pose is expressed wrt the parent link they are rigidly attached to. sdf_model.switch_frame_convention(frame_convention=rod.FrameConvention.Urdf) # Log type of base link. logging.debug( msg=f"Model '{sdf_model.name}' is {'fixed-base' if sdf_model.is_fixed_base() else 'floating-base'}" ) # Log detected base link. logging.debug(msg=f"Considering '{sdf_model.get_canonical_link()}' as base link") # Pose of the model if sdf_model.pose is None: model_pose = kinematic_graph.RootPose() else: W_H_M = sdf_model.pose.transform() model_pose = kinematic_graph.RootPose( root_position=W_H_M[0:3, 3], root_quaternion=Quaternion.from_dcm(dcm=W_H_M[0:3, 0:3]), ) # =========== # Parse links # =========== # Parse the links (unconnected). links = [ descriptions.LinkDescription( name=l.name, mass=float(l.inertial.mass), inertia=utils.from_sdf_inertial(inertial=l.inertial), pose=l.pose.transform() if l.pose is not None else np.eye(4), ) for l in sdf_model.links() if l.inertial.mass > 0 ] # Create a dictionary to find easily links. 
links_dict: dict[str, descriptions.LinkDescription] = {l.name: l for l in links} # ============ # Parse frames # ============ # Parse the frames (unconnected). frames = [ descriptions.LinkDescription( name=f.name, mass=jnp.array(0.0, dtype=float), inertia=jnp.zeros(shape=(3, 3)), parent_name=f.attached_to, pose=f.pose.transform() if f.pose is not None else jnp.eye(4), ) for f in sdf_model.frames() if f.attached_to in links_dict ] # ========================= # Process fixed-base models # ========================= # In this case, we need to get the pose of the joint that connects the base link # to the world and combine their pose. if sdf_model.is_fixed_base(): # Create a massless word link world_link = descriptions.LinkDescription( name="world", mass=0, inertia=np.zeros(shape=(6, 6)) ) # Gather joints connecting fixed-base models to the world. # TODO: the pose of this joint could be expressed wrt any arbitrary frame, # here we assume is expressed wrt the model. This also means that the # default model pose matches the pose of the fake "world" link. 
joints_with_world_parent = [ descriptions.JointDescription( name=j.name, parent=world_link, child=links_dict[j.child], jtype=utils.joint_to_joint_type(joint=j), axis=( np.array(j.axis.xyz.xyz) if j.axis is not None and j.axis.xyz is not None and j.axis.xyz.xyz is not None else None ), pose=j.pose.transform() if j.pose is not None else np.eye(4), ) for j in sdf_model.joints() if j.type == "fixed" and j.parent == "world" and j.child in links_dict and j.pose.relative_to in {"__model__", "world", None} ] logging.debug( f"Found joints connecting to world: {[j.name for j in joints_with_world_parent]}" ) if len(joints_with_world_parent) != 1: msg = "Found more/less than one joint connecting a fixed-base model to the world" raise ValueError(msg + f": {[j.name for j in joints_with_world_parent]}") base_link_name = joints_with_world_parent[0].child.name msg = "Combining the pose of base link '{}' with the pose of joint '{}'" logging.debug(msg.format(base_link_name, joints_with_world_parent[0].name)) # Combine the pose of the base link (child of the found fixed joint) # with the pose of the fixed joint connecting with the world. # Note: we assume it's a fixed joint and ignore any joint angle. links_dict[base_link_name].mutable(validate=False).pose = ( joints_with_world_parent[0].pose @ links_dict[base_link_name].pose ) # ============ # Parse joints # ============ # Check that all joint poses are expressed w.r.t. their parent link. for j in sdf_model.joints(): if j.pose is None: continue if j.parent == "world": if j.pose.relative_to in {"__model__", "world", None}: continue raise ValueError("Pose of fixed joint connecting to 'world' link not valid") if j.pose.relative_to != j.parent: msg = "Pose of joint '{}' is not expressed wrt its parent link '{}'" raise ValueError(msg.format(j.name, j.parent)) # Parse the joints. 
joints = [ descriptions.JointDescription( name=j.name, parent=links_dict[j.parent], child=links_dict[j.child], jtype=utils.joint_to_joint_type(joint=j), axis=( np.array(j.axis.xyz.xyz, dtype=float) if j.axis is not None and j.axis.xyz is not None and j.axis.xyz.xyz is not None else None ), pose=j.pose.transform() if j.pose is not None else np.eye(4), initial_position=0.0, position_limit=( float( j.axis.limit.lower if j.axis is not None and j.axis.limit is not None and j.axis.limit.lower is not None else jnp.finfo(float).min ), float( j.axis.limit.upper if j.axis is not None and j.axis.limit is not None and j.axis.limit.upper is not None else jnp.finfo(float).max ), ), friction_static=float( j.axis.dynamics.friction if j.axis is not None and j.axis.dynamics is not None and j.axis.dynamics.friction is not None else 0.0 ), friction_viscous=float( j.axis.dynamics.damping if j.axis is not None and j.axis.dynamics is not None and j.axis.dynamics.damping is not None else 0.0 ), position_limit_damper=float( j.axis.limit.dissipation if j.axis is not None and j.axis.limit is not None and j.axis.limit.dissipation is not None else os.environ.get("JAXSIM_JOINT_POSITION_LIMIT_DAMPER", 0.0) ), position_limit_spring=float( j.axis.limit.stiffness if j.axis is not None and j.axis.limit is not None and j.axis.limit.stiffness is not None else os.environ.get("JAXSIM_JOINT_POSITION_LIMIT_SPRING", 0.0) ), ) for j in sdf_model.joints() if j.type in {"revolute", "continuous", "prismatic", "fixed"} and j.parent != "world" and j.child in links_dict ] # Create a dictionary to find the parent joint of the links. joint_dict = {j.child.name: j.name for j in joints} # Check that all the link poses are expressed wrt their parent joint. 
for l in sdf_model.links(): if l.name not in links_dict: continue if l.pose is None: continue if l.name == sdf_model.get_canonical_link(): continue if l.name not in joint_dict: raise ValueError(f"Failed to find parent joint of link '{l.name}'") if l.pose.relative_to != joint_dict[l.name]: msg = "Pose of link '{}' is not expressed wrt its parent joint '{}'" raise ValueError(msg.format(l.name, joint_dict[l.name])) # ================ # Parse collisions # ================ # Initialize the collision shapes collisions: list[descriptions.CollisionShape] = [] # Parse the collisions for link in sdf_model.links(): for collision in link.collisions(): if collision.geometry.box is not None: box_collision = utils.create_box_collision( collision=collision, link_description=links_dict[link.name], ) collisions.append(box_collision) continue if collision.geometry.sphere is not None: sphere_collision = utils.create_sphere_collision( collision=collision, link_description=links_dict[link.name], ) collisions.append(sphere_collision) continue if collision.geometry.mesh is not None: if int(os.environ.get("JAXSIM_COLLISION_MESH_ENABLED", "0")): logging.warning("Mesh collision support is still experimental.") mesh_collision = utils.create_mesh_collision( collision=collision, link_description=links_dict[link.name], method=utils.meshes.extract_points_vertices, ) collisions.append(mesh_collision) else: logging.warning( f"Skipping collision shape 'mesh' in link '{link.name}' because mesh collisions are disabled." ) continue # Check any remaining non-None geometry types. for attr_name in collision.geometry.__dict__: if getattr(collision.geometry, attr_name) is not None: logging.warning( f"Skipping collision shape '{attr_name}' in link '{link.name}' as not supported." 
) return SDFData( model_name=sdf_model.name, link_descriptions=links, joint_descriptions=joints, frame_descriptions=frames, collision_shapes=collisions, fixed_base=sdf_model.is_fixed_base(), base_link_name=sdf_model.get_canonical_link(), model_pose=model_pose, sdf_model=sdf_model, ) def build_model_description( model_description: pathlib.Path | str | rod.Model, is_urdf: bool | None = None, ) -> descriptions.ModelDescription: """ Build a model description from an SDF/URDF resource. Args: model_description: A path to an SDF/URDF file, a string containing its content, or a pre-parsed/pre-built rod model. is_urdf: Whether the force parsing the resource as a URDF file. Automatically detected if not provided. Returns: The parsed model description. """ # Parse data from the SDF assuming it contains a single model. sdf_data = extract_model_data( model_description=model_description, model_name=None, is_urdf=is_urdf ) # Build the intermediate representation used for building a JaxSim model. # This process, beyond other operations, removes the fixed joints. # Note: if the model is fixed-base, the fixed joint between world and the first # link is removed and the pose of the first link is updated. # # The whole process is: # URDF/SDF ⟶ rod.Model ⟶ ModelDescription ⟶ JaxSimModel. 
graph = descriptions.ModelDescription.build_model_from( name=sdf_data.model_name, links=sdf_data.link_descriptions, joints=sdf_data.joint_descriptions, frames=sdf_data.frame_descriptions, collisions=sdf_data.collision_shapes, fixed_base=sdf_data.fixed_base, base_link_name=sdf_data.base_link_name, model_pose=sdf_data.model_pose, considered_joints=[ j.name for j in sdf_data.joint_descriptions if j.jtype is not descriptions.JointType.Fixed ], ) # Store the parsed SDF tree as extra info graph = dataclasses.replace(graph, _extra_info={"sdf_model": sdf_data.sdf_model}) return graph ================================================ FILE: src/jaxsim/parsers/rod/utils.py ================================================ import os import pathlib from collections.abc import Callable from typing import TypeVar import numpy as np import numpy.typing as npt import rod import trimesh from rod.utils.resolve_uris import resolve_local_uri import jaxsim.typing as jtp from jaxsim import logging from jaxsim.math import Adjoint, Inertia from jaxsim.parsers import descriptions from jaxsim.parsers.rod import meshes MeshMappingMethod = TypeVar("MeshMappingMethod", bound=Callable[..., npt.NDArray]) def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix: """ Extract the 6D inertia matrix from an SDF inertial element. Args: inertial: The SDF inertial element. Returns: The 6D inertia matrix of the link expressed in the link frame. """ # Extract the "mass" element. m = inertial.mass # Extract the "inertia" element. inertia_element = inertial.inertia ixx = inertia_element.ixx iyy = inertia_element.iyy izz = inertia_element.izz ixy = inertia_element.ixy if inertia_element.ixy is not None else 0.0 ixz = inertia_element.ixz if inertia_element.ixz is not None else 0.0 iyz = inertia_element.iyz if inertia_element.iyz is not None else 0.0 # Build the 3x3 inertia matrix expressed in the CoM. 
I_CoM = np.array( [ [ixx, ixy, ixz], [ixy, iyy, iyz], [ixz, iyz, izz], ] ) # Build the 6x6 generalized inertia at the CoM. M_CoM = Inertia.to_sixd(mass=m, com=np.zeros(3), I=I_CoM) # Compute the transform from the inertial frame (CoM) to the link frame. L_H_CoM = inertial.pose.transform() if inertial.pose is not None else np.eye(4) # We need its inverse. CoM_X_L = Adjoint.from_transform(transform=L_H_CoM, inverse=True) # Express the CoM inertia matrix in the link frame L. M_L = CoM_X_L.T @ M_CoM @ CoM_X_L return M_L.astype(dtype=float) def joint_to_joint_type(joint: rod.Joint) -> int: """ Extract the joint type from an SDF joint. Args: joint: The parsed SDF joint. Returns: The integer corresponding to the joint type. """ axis = joint.axis joint_type = joint.type if joint_type == "fixed": return descriptions.JointType.Fixed if not (axis.xyz is not None and axis.xyz.xyz is not None): raise ValueError("Failed to read axis xyz data") # Make sure that the axis is a unary vector. axis_xyz = np.array(axis.xyz.xyz).astype(float) axis_xyz = axis_xyz / np.linalg.norm(axis_xyz) if joint_type in {"revolute", "continuous"}: return descriptions.JointType.Revolute if joint_type == "prismatic": return descriptions.JointType.Prismatic raise ValueError("Joint not supported", axis_xyz, joint_type) def create_box_collision( collision: rod.Collision, link_description: descriptions.LinkDescription ) -> descriptions.BoxCollision: """ Create a box collision from an SDF collision element. Args: collision: The SDF collision element. link_description: The link description. Returns: The box collision description. """ x, y, z = collision.geometry.box.size center = np.array([x / 2, y / 2, z / 2]) # Define the bottom corners. bottom_corners = np.array([[0, 0, 0], [x, 0, 0], [x, y, 0], [0, y, 0]]) # Conditionally add the top corners based on the environment variable. 
top_corners = ( np.array([[0, 0, z], [x, 0, z], [x, y, z], [0, y, z]]) if os.environ.get("JAXSIM_COLLISION_USE_BOTTOM_ONLY", "0").lower() in { "false", "0", } else [] ) # Combine and shift by the center box_corners = np.vstack([bottom_corners, *top_corners]) - center H = collision.pose.transform() if collision.pose is not None else np.eye(4) center_wrt_link = (H @ np.hstack([center, 1.0]))[0:-1] box_corners_wrt_link = ( H @ np.hstack([box_corners, np.vstack([1.0] * box_corners.shape[0])]).T )[0:3, :] collidable_points = [ descriptions.CollidablePoint( parent_link=link_description, position=np.array(corner), enabled=True, ) for corner in box_corners_wrt_link.T ] return descriptions.BoxCollision( collidable_points=collidable_points, center=center_wrt_link ) def create_sphere_collision( collision: rod.Collision, link_description: descriptions.LinkDescription ) -> descriptions.SphereCollision: """ Create a sphere collision from an SDF collision element. Args: collision: The SDF collision element. link_description: The link description. Returns: The sphere collision description. """ # From https://stackoverflow.com/a/26127012 def fibonacci_sphere(samples: int) -> npt.NDArray: # Get the golden ratio in radians. phi = np.pi * (3.0 - np.sqrt(5.0)) # Generate the points. points = [ np.array( [ np.cos(phi * i) * np.sqrt(1 - (y := 1 - 2 * i / (samples - 1)) ** 2), y, np.sin(phi * i) * np.sqrt(1 - y**2), ] ) for i in range(samples) ] # Filter to keep only the bottom half if required. if os.environ.get("JAXSIM_COLLISION_USE_BOTTOM_ONLY", "0").lower() in { "true", "1", }: # Keep only the points with z <= 0. 
points = [point for point in points if point[2] <= 0] return np.vstack(points) r = collision.geometry.sphere.radius sphere_points = r * fibonacci_sphere( samples=int(os.getenv(key="JAXSIM_COLLISION_SPHERE_POINTS", default="50")) ) H = collision.pose.transform() if collision.pose is not None else np.eye(4) center_wrt_link = (H @ np.hstack([0, 0, 0, 1.0]))[0:-1] sphere_points_wrt_link = ( H @ np.hstack([sphere_points, np.vstack([1.0] * sphere_points.shape[0])]).T )[0:3, :] collidable_points = [ descriptions.CollidablePoint( parent_link=link_description, position=np.array(point), enabled=True, ) for point in sphere_points_wrt_link.T ] return descriptions.SphereCollision( collidable_points=collidable_points, center=center_wrt_link ) def create_mesh_collision( collision: rod.Collision, link_description: descriptions.LinkDescription, method: MeshMappingMethod = None, ) -> descriptions.MeshCollision: """ Create a mesh collision from an SDF collision element. Args: collision: The SDF collision element. link_description: The link description. method: The method to use for mesh wrapping. Returns: The mesh collision description. 
""" file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri)) file_type = file.suffix.replace(".", "") mesh = trimesh.load_mesh(file, file_type=file_type) if mesh.is_empty: raise RuntimeError(f"Failed to process '{file}' with trimesh") mesh.apply_scale(collision.geometry.mesh.scale) logging.info( msg=f"Loading mesh {collision.geometry.mesh.uri} with scale {collision.geometry.mesh.scale}, file type '{file_type}'" ) if method is None: method = meshes.VertexExtraction() logging.debug("Using default Vertex Extraction method for mesh wrapping") else: logging.debug(f"Using method {method} for mesh wrapping") points = method(mesh=mesh) logging.debug(f"Extracted {len(points)} points from mesh") W_H_L = collision.pose.transform() if collision.pose is not None else np.eye(4) # Extract translation from transformation matrix W_p_L = W_H_L[:3, 3] mesh_points_wrt_link = points @ W_H_L[:3, :3].T + W_p_L collidable_points = [ descriptions.CollidablePoint( parent_link=link_description, position=point, enabled=True, ) for point in mesh_points_wrt_link ] return descriptions.MeshCollision(collidable_points=collidable_points, center=W_p_L) def prepare_mesh_for_parametrization( mesh_uri: str, scale: tuple[float, float, float] = (1.0, 1.0, 1.0) ) -> dict: """ Load and prepare a mesh for parametric scaling with exact inertia computation. This function loads a mesh, ensures it's watertight (crucial for volume/inertia calculation), centers it, and returns the data needed for parametric scaling. Args: mesh_uri: URI/path to the mesh file. scale: Initial scale factors to apply (from SDF/URDF). 
Returns: A dictionary containing: - 'vertices': Centered mesh vertices as numpy array (Nx3) - 'faces': Triangle faces as numpy array (Mx3 integer indices) - 'offset': Original mesh centroid offset as numpy array (3,) - 'uri': The mesh URI for reference - 'is_watertight': Boolean indicating if mesh is watertight - 'volume': The volume of the mesh (after scaling) """ # Load mesh file = pathlib.Path(resolve_local_uri(uri=mesh_uri)) file_type = file.suffix.replace(".", "") mesh = trimesh.load_mesh(file, file_type=file_type) if mesh.is_empty: raise RuntimeError(f"Failed to process '{file}' with trimesh") # Apply initial scale from SDF/URDF mesh.apply_scale(scale) # Check and fix watertightness is_watertight = mesh.is_watertight if not is_watertight: logging.warning( f"Mesh {mesh_uri} is not watertight. Computing convex hull for valid inertia." ) mesh = mesh.convex_hull is_watertight = True # Store original centroid as offset offset = mesh.centroid.copy() # Center the mesh mesh.vertices -= offset return { "vertices": np.array(mesh.vertices, dtype=np.float64), "faces": np.array(mesh.faces, dtype=np.int32), "offset": np.array(offset, dtype=np.float64), "uri": mesh_uri, "is_watertight": is_watertight, "volume": mesh.volume, } ================================================ FILE: src/jaxsim/rbda/__init__.py ================================================ from . 
import actuation, contacts
from .aba import aba
from .aba_parallel import aba_parallel
from .collidable_points import collidable_points_pos_vel
from .crba import crba
from .forward_kinematics import forward_kinematics_model
from .forward_kinematics_parallel import forward_kinematics_model_parallel
from .jacobian import (
    jacobian,
    jacobian_derivative_full_doubly_left,
    jacobian_full_doubly_left,
)
from .kinematic_constraints import compute_constraint_wrenches
from .mass_inverse import mass_inverse
from .rnea import rnea

================================================
FILE: src/jaxsim/rbda/aba.py
================================================
import jax
import jax.numpy as jnp
import jaxlie

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import STANDARD_GRAVITY, Adjoint, Cross

from . import utils


def aba(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike,
    base_quaternion: jtp.VectorLike,
    joint_positions: jtp.VectorLike,
    base_linear_velocity: jtp.VectorLike,
    base_angular_velocity: jtp.VectorLike,
    joint_velocities: jtp.VectorLike,
    joint_transforms: jtp.MatrixLike,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute forward dynamics using the Articulated Body Algorithm (ABA).

    The computation is split into three sequential sweeps over the links, each
    implemented with ``jax.lax.scan``:

    1. forward: propagate link velocities and initialize the articulated-body
       inertias and bias forces;
    2. backward (leaves to root): accumulate articulated-body inertias and bias
       forces into the parents;
    3. forward: propagate link accelerations and extract joint accelerations.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        joint_velocities: The velocities of the joints.
        joint_transforms: The parent-to-child transforms of the joints.
        joint_forces: The forces applied to the joints.
        link_forces:
            The forces applied to the links expressed in the world frame.
        standard_gravity: The standard gravity constant.

    Returns:
        A tuple containing the base acceleration in inertial-fixed representation
        and the joint accelerations that result from the applications of the given
        joint and link forces. For fixed-base models the returned base
        acceleration is a zero 6D vector.

    Note:
        The algorithm expects a quaternion with unit norm.
    """

    W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, τ, W_f, W_g = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
        base_linear_acceleration=None,
        base_angular_acceleration=None,
        joint_accelerations=None,
        joint_forces=joint_forces,
        link_forces=link_forces,
        standard_gravity=standard_gravity,
    )

    # Turn gravity and base velocity into column vectors so that they can be
    # multiplied by the 6x6 adjoints below (both are used as 6-row operands).
    W_g = jnp.atleast_2d(W_g).T
    W_v_WB = jnp.atleast_2d(W_v_WB).T

    # Get the 6D spatial inertia matrices of all links.
    M = js.model.link_spatial_inertia_matrices(model=model)

    # Get the parent array λ(i).
    # Note: λ(0) must not be used, it's initialized to -1.
    λ = model.kin_dyn_parameters.parent_array

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3(wxyz=W_Q_B),
        translation=W_p_B,
    )

    # Compute 6D transforms of the base velocity.
    W_X_B = W_H_B.adjoint()
    B_X_W = W_H_B.inverse().adjoint()

    # Extract the parent-to-child adjoints of the joints.
    i_X_λi = jnp.asarray(joint_transforms)

    # Extract the joint motion subspaces.
    S = model.kin_dyn_parameters.motion_subspaces

    # Allocate buffers (one entry per link, link 0 being the base).
    v = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    c = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    pA = jnp.zeros(shape=(model.number_of_links(), 6, 1))
    MA = jnp.zeros(shape=(model.number_of_links(), 6, 6))

    # Allocate the buffer of transforms link -> base.
    i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Initialize base quantities.
    # For fixed-base models v[0], MA[0] and pA[0] stay zero.
    if model.floating_base():

        # Base velocity v₀ in body-fixed representation.
        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        # Initialize the articulated-body inertia (Mᴬ) of base link.
        MA_0 = M[0]
        MA = MA.at[0].set(MA_0)

        # Initialize the articulated-body bias force (pᴬ) of the base link.
        pA_0 = Cross.vx_star(v[0]) @ MA[0] @ v[0] - W_X_B.T @ jnp.vstack(W_f[0])
        pA = pA.at[0].set(pA_0)

    # ======
    # Pass 1
    # ======

    Pass1Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
    pass_1_carry: Pass1Carry = (v, c, MA, pA, i_X_0)

    # Propagate kinematics and initialize AB inertia and AB bias forces.
    def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]:
        """Process link i (i >= 1) given the already-processed parent λ(i)."""

        # Joint-indexed arrays (ṡ, τ, s̈) have no entry for the base link,
        # so link i is driven by joint i - 1.
        ii = i - 1
        v, c, MA, pA, i_X_0 = carry

        # Project the joint velocity into its motion subspace.
        vJ = S[i] * ṡ[ii]

        # Propagate the link velocity.
        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        c_i = Cross.vx(v[i]) @ vJ
        c = c.at[i].set(c_i)

        # Initialize the articulated-body inertia.
        MA_i = jnp.array(M[i])
        MA = MA.at[i].set(MA_i)

        # Compute the link-to-base transform.
        i_Xi_0 = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_Xi_0)

        # Compute link-to-world transform for the 6D force.
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        # Initialize articulated-body bias force.
        pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(W_f[i])
        pA = pA.at[i].set(pA_i)

        return (v, c, MA, pA, i_X_0), None

    # Links are visited in increasing index order, so this relies on parents
    # having a lower index than their children. With a single link there is
    # nothing to iterate over and the carry is returned unchanged.
    (v, c, MA, pA, i_X_0), _ = (
        jax.lax.scan(
            f=loop_body_pass1,
            init=pass_1_carry,
            xs=jnp.arange(start=1, stop=model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(v, c, MA, pA, i_X_0), None]
    )

    # ======
    # Pass 2
    # ======

    U = jnp.zeros_like(S)
    d = jnp.zeros(shape=(model.number_of_links(), 1))
    u = jnp.zeros(shape=(model.number_of_links(), 1))

    Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
    pass_2_carry: Pass2Carry = (U, d, u, MA, pA)

    def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:
        """Fold the articulated quantities of link i into its parent λ(i)."""

        ii = i - 1
        U, d, u, MA, pA = carry

        U_i = MA[i] @ S[i]
        U = U.at[i].set(U_i)

        d_i = S[i].T @ U[i]
        d = d.at[i].set(d_i.squeeze())

        u_i = τ[ii] - S[i].T @ pA[i]
        u = u.at[i].set(u_i.squeeze())

        # Compute the articulated-body inertia and bias force of this link.
        # Note: `/` and `@` share precedence, so this is (U / d) @ Uᵀ.
        Ma = MA[i] - U[i] / d[i] @ U[i].T
        pa = pA[i] + Ma @ c[i] + U[i] * (u[i] / d[i])

        # Propagate them to the parent, handling the base link.
        def propagate(
            MA_pA: tuple[jtp.Matrix, jtp.Matrix],
        ) -> tuple[jtp.Matrix, jtp.Matrix]:
            """Add the transformed Ma and pa of link i to those of λ(i)."""

            MA, pA = MA_pA

            MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
            MA = MA.at[λ[i]].set(MA_λi)

            pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa
            pA = pA.at[λ[i]].set(pA_λi)

            return MA, pA

        # Skip the propagation only when the parent is the base of a
        # fixed-base model, whose quantities are never used in pass 3.
        MA, pA = jax.lax.cond(
            pred=jnp.logical_or(λ[i] != 0, model.floating_base()),
            true_fun=propagate,
            false_fun=lambda MA_pA: MA_pA,
            operand=(MA, pA),
        )

        return (U, d, u, MA, pA), None

    # Visit links in decreasing index order (children before parents).
    (U, d, u, MA, pA), _ = (
        jax.lax.scan(
            f=loop_body_pass2,
            init=pass_2_carry,
            xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
        )
        if model.number_of_links() > 1
        else [(U, d, u, MA, pA), None]
    )

    # ======
    # Pass 3
    # ======

    # Floating base: solve MA[0] a0 = -pA[0] for the base acceleration.
    # Fixed base: seed the base with the negated gravity expressed in the base
    # frame, so that gravity reaches every link through the propagation below.
    if model.floating_base():
        a0 = jnp.linalg.solve(-MA[0], pA[0])
    else:
        a0 = -B_X_W @ W_g

    s̈ = jnp.zeros_like(s)
    a = jnp.zeros_like(v).at[0].set(a0)

    Pass3Carry = tuple[jtp.Matrix, jtp.Vector]
    pass_3_carry = (a, s̈)

    def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]:
        """Compute the acceleration of link i and of its parent joint i - 1."""

        ii = i - 1
        a, s̈ = carry

        # Propagate the link acceleration.
        a_i = i_X_λi[i] @ a[λ[i]] + c[i]

        # Compute the joint acceleration.
        s̈_ii = (u[i] - U[i].T @ a_i) / d[i]
        s̈ = s̈.at[ii].set(s̈_ii.squeeze())

        # Sum the joint acceleration to the parent link acceleration.
        a_i = a_i + S[i] * s̈[ii]
        a = a.at[i].set(a_i)

        return (a, s̈), None

    (a, s̈), _ = (
        jax.lax.scan(
            f=loop_body_pass3,
            init=pass_3_carry,
            xs=jnp.arange(1, model.number_of_links()),
        )
        if model.number_of_links() > 1
        else [(a, s̈), None]
    )

    # ==============
    # Adjust outputs
    # ==============

    # TODO: remove vstack and shape=(6, 1)?
    if model.floating_base():
        # Convert the base acceleration to inertial-fixed representation,
        # and add gravity.
        B_a_WB = a[0]
        W_a_WB = W_X_B @ B_a_WB + W_g

    else:
        W_a_WB = jnp.zeros(6)

    return W_a_WB.squeeze(), jnp.atleast_1d(s̈.squeeze())

================================================
FILE: src/jaxsim/rbda/aba_parallel.py
================================================
import math

import jax
import jax.numpy as jnp
import jaxlie

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import STANDARD_GRAVITY, Adjoint, Cross

from . import utils


def aba_parallel(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike,
    base_quaternion: jtp.VectorLike,
    joint_positions: jtp.VectorLike,
    base_linear_velocity: jtp.VectorLike,
    base_angular_velocity: jtp.VectorLike,
    joint_velocities: jtp.VectorLike,
    joint_transforms: jtp.MatrixLike,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Compute forward dynamics using a hybrid parallel ABA.

    Passes 1 and 3 use pointer jumping in O(log D) parallel steps.
    Pass 2 uses level-parallel processing in O(D) steps because the
    backward inertia accumulation is not associative.

    The interface and semantics are identical to :func:`aba`, but passes
    1 and 3 are parallelized via pointer jumping.

    Returns:
        A tuple containing the base acceleration in inertial-fixed representation
        (a zero 6D vector for fixed-base models) and the joint accelerations.
    """

    W_p_B, W_Q_B, _, W_v_WB, ṡ, _, _, τ, W_f, W_g = utils.process_inputs(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        joint_velocities=joint_velocities,
        base_linear_acceleration=None,
        base_angular_acceleration=None,
        joint_accelerations=None,
        joint_forces=joint_forces,
        link_forces=link_forces,
        standard_gravity=standard_gravity,
    )

    # Column-vector form, as in `aba`.
    W_g = jnp.atleast_2d(W_g).T
    W_v_WB = jnp.atleast_2d(W_v_WB).T

    # Get the 6D spatial inertia matrices of all links.
    M = js.model.link_spatial_inertia_matrices(model=model)

    # Get the parent array λ(i).
    λ = model.kin_dyn_parameters.parent_array

    # Get the tree level structure for level-parallel processing.
    # NOTE(review): presumably level_nodes[k] lists the (padded) link indices
    # at depth k and level_mask[k] flags the non-padded entries — confirm
    # against KinDynParameters.
    level_nodes = jnp.asarray(model.kin_dyn_parameters.level_nodes)
    level_mask = jnp.asarray(model.kin_dyn_parameters.level_mask)
    n_levels = level_nodes.shape[0]

    # Compute the base transform.
    W_H_B = jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3(wxyz=W_Q_B),
        translation=W_p_B,
    )

    # Compute 6D transforms of the base velocity.
    W_X_B = W_H_B.adjoint()
    B_X_W = W_H_B.inverse().adjoint()

    # Extract the parent-to-child adjoints of the joints.
    i_X_λi = jnp.asarray(joint_transforms)

    # Extract the joint motion subspaces.
    S = model.kin_dyn_parameters.motion_subspaces

    n = model.number_of_links()

    # Parent array with root self-loop.
    # Note: λ(0) is set to 0 to enable root self-referencing.
    ptr0 = jnp.asarray(λ).at[0].set(0)

    # Number of pointer-jumping rounds. Each round doubles the number of
    # ancestors folded into every node, so ceil(log2(n_levels)) rounds cover
    # the deepest path (assumes n_levels counts the root level — TODO confirm).
    n_rounds = max(1, math.ceil(math.log2(max(n_levels, 2))))

    # ======
    # Pass 1
    # ======
    # Two coupled affine recurrences propagated via pointer jumping:
    #   v_i = i_X_λi[i] @ v_parent + vJ_i
    #   T_i = i_X_λi[i] @ T_parent
    #
    # Associative operator on (A, b, T):
    #   compose(parent, child) = (A @ A_p, A @ b_p + b, A @ T_p)

    # Local transforms and joint velocities.
    # A zero row is prepended so that link i reads joint i - 1 and the base
    # link (index 0) gets a zero joint velocity.
    ṡ_col = jnp.atleast_1d(ṡ).reshape(-1, 1)  # (n_joints, 1)
    ṡ_padded = jnp.concatenate([jnp.zeros((1, 1)), ṡ_col])  # (n, 1)
    vJ = S * ṡ_padded[:, :, None]  # (n, 6, 1)

    # Initialize pointer-jumping state for each node.
    # NOTE(review): T starts equal to A (the root is the identity in both) and
    # is updated with the same left factor, so it appears to always equal A.
    A = i_X_λi.copy()  # (n, 6, 6)
    b = vJ.copy()  # (n, 6, 1)
    T = i_X_λi.copy()  # (n, 6, 6)

    # Root initial values.
    if model.floating_base():
        v_0 = B_X_W @ W_v_WB
        A = A.at[0].set(jnp.eye(6))
        b = b.at[0].set(v_0)
        T = T.at[0].set(jnp.eye(6))
    else:
        A = A.at[0].set(jnp.eye(6))
        b = b.at[0].set(jnp.zeros((6, 1)))
        T = T.at[0].set(jnp.eye(6))

    # Only the root starts with its final value.
    ptr = ptr0.copy()
    done = jnp.arange(n) == 0

    def _pass1_jump(carry, _):
        """One pointer-jumping round: fold each pending node into its pointee."""

        A, b, T, ptr, done = carry
        need = ~done

        A_par = A[ptr]
        b_par = b[ptr]
        T_par = T[ptr]

        # Associative compose.
        A_new = jnp.where(need[:, None, None], A @ A_par, A)
        b_new = jnp.where(need[:, None, None], A @ b_par + b, b)
        T_new = jnp.where(need[:, None, None], A @ T_par, T)

        # Jump to the grandparent; a node is final once its pointee was final.
        ptr_new = jnp.where(need, ptr[ptr], ptr)
        done_new = done | done[ptr]

        return (A_new, b_new, T_new, ptr_new, done_new), None

    (_, v, i_X_0, _, _), _ = (
        jax.lax.scan(
            f=_pass1_jump,
            init=(A, b, T, ptr, done),
            xs=jnp.arange(n_rounds),
        )
        if n > 1
        else ((A, b, T, ptr, done), None)
    )

    # v now contains the 6D body velocity of every link.
    # i_X_0 contains the body-to-base transform for every link.

    # Compute c, MA, pA for all nodes in parallel.
    def _init_node(node_i):
        """Per-link bias acceleration, AB inertia and AB bias force (pass 1)."""

        vJ_i = S[node_i] * ṡ_padded[node_i]
        c_i = Cross.vx(v[node_i]) @ vJ_i
        MA_i = M[node_i]
        i_Xf_W = Adjoint.inverse(i_X_0[node_i] @ B_X_W).T
        pA_i = Cross.vx_star(v[node_i]) @ M[node_i] @ v[node_i] - i_Xf_W @ jnp.vstack(
            W_f[node_i]
        )
        return c_i, MA_i, pA_i

    c, MA, pA = jax.vmap(_init_node)(jnp.arange(n))

    # Override base MA and pA if floating base.
    if model.floating_base():
        MA = MA.at[0].set(M[0])
        pA_0 = Cross.vx_star(v[0]) @ M[0] @ v[0] - W_X_B.T @ jnp.vstack(W_f[0])
        pA = pA.at[0].set(pA_0)

    # ======
    # Pass 2
    # ======
    # The Schur complement and multi-child scatter-add make this pass
    # non-associative, so it remains level-parallel.

    U = jnp.zeros_like(S)
    d = jnp.ones(shape=(n, 1))  # Ones to avoid NaN for the base node.
    u = jnp.zeros(shape=(n, 1))

    def _masked_scatter_add(arr, indices, values, m):
        """Add values[j] to arr[indices[j]] only where m[j] is True."""
        # Broadcast the 1-D mask over the trailing dims of `values`; repeated
        # indices (siblings sharing a parent) are summed by `.add`.
        mask = jnp.reshape(m, m.shape + (1,) * (values.ndim - 1))
        masked_values = jnp.where(mask, values, jnp.zeros_like(values))
        return arr.at[indices].add(masked_values)

    def _pass2_level(carry, level_idx):
        """Process every node of one tree level, deepest level first."""

        # NOTE(review): U, d, u are immediately rebound from carry[...] below;
        # MA and pA are read by `_process_node_pass2` through the closure
        # before being rebound, so the vmap sees the incoming carry values.
        U, d, u, MA, pA = carry
        actual_level = n_levels - 1 - level_idx
        nodes = level_nodes[actual_level]
        mask = level_mask[actual_level]

        def _process_node_pass2(node_i):
            """Compute U, d, u of node_i and its contribution to the parent."""

            # Clamp index to avoid out-of-bounds for padded entries.
            ii = jnp.maximum(node_i, 1) - 1
            parent = λ[node_i]

            U_i = MA[node_i] @ S[node_i]
            d_i = (S[node_i].T @ U_i).squeeze()
            u_i = (τ[ii] - S[node_i].T @ pA[node_i]).squeeze()

            Ma_i = MA[node_i] - U_i / d_i @ U_i.T
            pa_i = pA[node_i] + Ma_i @ c[node_i] + U_i * (u_i / d_i)

            Ma_parent = i_X_λi[node_i].T @ Ma_i @ i_X_λi[node_i]
            pa_parent = i_X_λi[node_i].T @ pa_i

            return U_i, d_i, u_i, Ma_parent, pa_parent, parent

        U_lev, d_lev, u_lev, Ma_par, pa_par, parents = jax.vmap(_process_node_pass2)(
            nodes
        )

        # Only write back the non-padded entries of this level.
        mask_6x1 = mask[:, None, None]
        mask_1 = mask[:, None]

        U = carry[0].at[nodes].set(jnp.where(mask_6x1, U_lev, carry[0][nodes]))
        d = carry[1].at[nodes].set(jnp.where(mask_1, d_lev[:, None], carry[1][nodes]))
        u = carry[2].at[nodes].set(jnp.where(mask_1, u_lev[:, None], carry[2][nodes]))

        # As in `aba`, skip the propagation into the base of fixed-base models.
        should_propagate = jnp.where(
            model.floating_base(),
            mask,
            jnp.logical_and(mask, parents != 0),
        )

        MA = _masked_scatter_add(carry[3], parents, Ma_par, should_propagate)
        pA = _masked_scatter_add(carry[4], parents, pa_par, should_propagate)

        return (U, d, u, MA, pA), None

    # The root level is never processed: its quantities are consumed directly
    # when computing a0 below.
    n_backward_levels = n_levels - 1
    (U, d, u, MA, pA), _ = (
        jax.lax.scan(
            f=_pass2_level,
            init=(U, d, u, MA, pA),
            xs=jnp.arange(n_backward_levels),
        )
        if n_backward_levels > 0
        else ((U, d, u, MA, pA), None)
    )

    # ======
    # Pass 3
    # ======
    # The acceleration recurrence is an affine recurrence:
    #   a_i = P_i @ i_X_λi[i] @ a_parent + P_i @ c_i + S_i * u_i / d_i
    # where P_i = I - S_i @ U_i^T / d_i is the 6x6 projection matrix.

    # Same base-acceleration seeding as in `aba`.
    if model.floating_base():
        a0 = jnp.linalg.solve(-MA[0], pA[0])
    else:
        a0 = -B_X_W @ W_g

    # Pre-compute the affine recurrence coefficients for all nodes.
    def _init_pass3(node_i):
        """Return (A_i, b_i) such that a_i = A_i @ a_parent + b_i."""

        P_i = jnp.eye(6) - S[node_i] @ U[node_i].T / d[node_i]
        A_i = P_i @ i_X_λi[node_i]
        b_i = P_i @ c[node_i] + S[node_i] * (u[node_i] / d[node_i])
        return A_i, b_i

    A, b = jax.vmap(_init_pass3)(jnp.arange(n))

    # Root acceleration is known.
    A = A.at[0].set(jnp.eye(6))
    b = b.at[0].set(a0)

    # Pointer jumping for the affine recurrence.
    ptr = ptr0.copy()
    done = jnp.arange(n) == 0

    def _pass3_jump(carry, _):
        """One pointer-jumping round of the acceleration recurrence."""

        A, b, ptr, done = carry
        need = ~done

        A_par = A[ptr]
        b_par = b[ptr]

        # Associative compose.
        A_new = jnp.where(need[:, None, None], A @ A_par, A)
        b_new = jnp.where(need[:, None, None], A @ b_par + b, b)

        ptr_new = jnp.where(need, ptr[ptr], ptr)
        done_new = done | done[ptr]

        return (A_new, b_new, ptr_new, done_new), None

    (_, a, _, _), _ = (
        jax.lax.scan(
            f=_pass3_jump,
            init=(A, b, ptr, done),
            xs=jnp.arange(n_rounds),
        )
        if n > 1
        else ((A, b, ptr, done), None)
    )

    # Recover joint accelerations: s̈_i = (u_i - U_i^T @ a_before_i) / d_i
    # where a_before_i = i_X_λi[i] @ a_parent + c_i.
    a_λi = a[ptr0]
    a_before = i_X_λi @ a_λi + c
    Ut_a = (U.transpose(0, 2, 1) @ a_before).squeeze(-1)  # (n, 1)
    s̈ = (u - Ut_a) / d  # (n, 1)

    # ==============
    # Adjust outputs
    # ==============

    if model.floating_base():
        B_a_WB = a[0]
        W_a_WB = W_X_B @ B_a_WB + W_g
    else:
        W_a_WB = jnp.zeros(6)

    # Joint accelerations: skip base index, take indices 1..n-1.
    s̈_out = s̈[1:]

    return W_a_WB.squeeze(), jnp.atleast_1d(s̈_out.squeeze())

================================================
FILE: src/jaxsim/rbda/actuation/__init__.py
================================================
from .common import ActuationParams

================================================
FILE: src/jaxsim/rbda/actuation/common.py
================================================
import dataclasses

import jax_dataclasses
from jax_dataclasses import Static

import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass


@jax_dataclasses.pytree_dataclass
class ActuationParams(JaxsimDataclass):
    """
    Parameters class for the actuation model.

    The numeric fields are pytree leaves; ``enable_friction`` is declared
    ``Static`` and is therefore part of the pytree structure rather than a leaf.
    """

    # NOTE(review): the meaning of these limits (e.g. whether omega_th is the
    # speed at which torque saturation starts) is defined by the actuation
    # model that consumes them — confirm there.
    torque_max: jtp.Float = dataclasses.field(default=3000.0)  # (Nm)
    omega_th: jtp.Float = dataclasses.field(default=30.0)  # (rad/s)
    omega_max: jtp.Float = dataclasses.field(default=100.0)  # (rad/s)
    enable_friction: Static[bool] = dataclasses.field(default=True)

================================================
FILE: src/jaxsim/rbda/collidable_points.py
================================================
import jax
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Skew


def collidable_points_pos_vel(
    model: js.model.JaxSimModel,
    *,
    link_transforms: jtp.Matrix,
    link_velocities: jtp.Matrix,
) -> tuple[jtp.Matrix, jtp.Matrix]:
    """
    Compute the position and linear velocity of the enabled collidable points in the world frame.

    Args:
        model: The model to consider.
        link_transforms: The transforms from the world frame to each link.
        link_velocities: The linear and angular velocities of each link.

    Returns:
        A tuple containing the position and linear velocity of the enabled
        collidable points, one 3D row per point.
    """

    # Get the indices of the enabled collidable points.
    indices_of_enabled_collidable_points = (
        model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
    )

    parent_link_idx_of_enabled_collidable_points = jnp.array(
        model.kin_dyn_parameters.contact_parameters.body, dtype=int
    )[indices_of_enabled_collidable_points]

    L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[
        indices_of_enabled_collidable_points
    ]

    # NOTE(review): the empty case returns a 0-d array and a 1-D empty array,
    # whose ranks differ from the (n_points, 3) arrays returned below —
    # confirm callers handle this, or consider jnp.empty((0, 3)) for both.
    if len(indices_of_enabled_collidable_points) == 0:
        return jnp.array(0).astype(float), jnp.empty(0).astype(float)

    def process_point_kinematics(
        Li_p_C: jtp.Vector, parent_body: jtp.Int
    ) -> tuple[jtp.Vector, jtp.Vector]:
        """Return the world position and linear velocity of a single point."""

        # Compute the position of the collidable point.
        # Homogeneous coordinates: append 1, transform, keep the xyz part.
        W_p_Ci = (link_transforms[parent_body] @ jnp.hstack([Li_p_C, 1]))[0:3]

        # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}.
        # NOTE(review): [I, -S(p)] @ v assumes link_velocities are 6D
        # inertial-fixed velocities with the linear part first — confirm.
        CW_vl_WCi = (
            jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()])
            @ link_velocities[parent_body].squeeze()
        )

        return W_p_Ci, CW_vl_WCi

    # Process all the collidable points in parallel.
    W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)(
        L_p_Ci,
        parent_link_idx_of_enabled_collidable_points,
    )

    return W_p_Ci, CW_vl_WC

================================================
FILE: src/jaxsim/rbda/contacts/__init__.py
================================================
from . import relaxed_rigid, rigid, soft
from .common import ContactModel, ContactsParams
from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
from .rigid import RigidContacts, RigidContactsParams
from .soft import SoftContacts, SoftContactsParams

# Union of all the supported contact-parameter classes.
ContactParamsTypes = (
    SoftContactsParams | RigidContactsParams | RelaxedRigidContactsParams
)

================================================
FILE: src/jaxsim/rbda/contacts/common.py
================================================
from __future__ import annotations

import abc
import functools

import jax
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.terrain
import jaxsim.typing as jtp
from jaxsim.math import STANDARD_GRAVITY
from jaxsim.utils import JaxsimDataclass

# `typing.Self` exists from Python 3.11; fall back to typing_extensions.
try:
    from typing import Self
except ImportError:
    from typing_extensions import Self

# Upper bounds applied to the automatically computed stiffness and damping
# in `ContactsParams.build_default_from_jaxsim_model`.
MAX_STIFFNESS = 1e6
MAX_DAMPING = 1e4


# `terrain` is a static argument: it must be hashable, and a different
# terrain object triggers a new compilation.
@functools.partial(jax.jit, static_argnames=("terrain",))
def compute_penetration_data(
    p: jtp.VectorLike,
    v: jtp.VectorLike,
    terrain: jaxsim.terrain.Terrain,
) -> tuple[jtp.Float, jtp.Float, jtp.Vector]:
    """
    Compute the penetration data (depth, rate, and terrain normal) of a collidable point.

    Args:
        p: The position of the collidable point.
        v:
            The linear velocity of the point (linear component of the mixed 6D velocity
            of the implicit frame `C = (W_p_C, [W])` associated to the point).
        terrain: The considered terrain.

    Returns:
        A tuple containing the penetration depth, the penetration velocity,
        and the considered terrain normal. Both depth and rate are zero when
        the point is not below the terrain.
    """

    # Pre-process the position and the linear velocity of the collidable point.
    W_ṗ_C = jnp.array(v).squeeze()
    px, py, pz = jnp.array(p).squeeze()

    # Compute the terrain normal and the contact depth.
    n̂ = terrain.normal(x=px, y=py).squeeze()
    h = jnp.array([0, 0, terrain.height(x=px, y=py) - pz])

    # Compute the penetration depth normal to the terrain.
    δ = jnp.maximum(0.0, jnp.dot(h, n̂))

    # Compute the penetration normal velocity.
    δ_dot = -jnp.dot(W_ṗ_C, n̂)

    # Enforce the penetration rate to be zero when the penetration depth is zero.
    δ_dot = jnp.where(δ > 0, δ_dot, 0.0)

    return δ, δ_dot, n̂


class ContactsParams(JaxsimDataclass):
    """
    Abstract class representing the parameters of a contact model.

    Note:
        This class is supposed to store only the tunable parameters of the contact
        model, i.e. all those parameters that can be changed during runtime.
        If the contact model has also static parameters, they should be stored
        in the corresponding `ContactModel` class.
    """

    @classmethod
    @abc.abstractmethod
    def build(cls: type[Self], **kwargs) -> Self:
        """
        Create a `ContactsParams` instance with specified parameters.

        Returns:
            The `ContactsParams` instance.
        """
        pass

    # NOTE(review): this is a regular method (no @classmethod), yet `self` is
    # annotated as `type[Self]`; it only calls `self.build`, a classmethod, so
    # both an instance and a class receiver work at runtime — confirm intent.
    def build_default_from_jaxsim_model(
        self: type[Self],
        model: js.model.JaxSimModel,
        *,
        stiffness: jtp.FloatLike | None = None,
        damping: jtp.FloatLike | None = None,
        standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,
        static_friction_coefficient: jtp.FloatLike = 0.5,
        max_penetration: jtp.FloatLike = 0.001,
        number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
        damping_ratio: jtp.FloatLike = 1.0,
        p: jtp.FloatLike = 0.5,
        q: jtp.FloatLike = 0.5,
        **kwargs,
    ) -> Self:
        """
        Create a `ContactsParams` instance with default parameters.

        Args:
            model: The robot model considered by the contact model.
            stiffness: The stiffness of the contact model.
            damping: The damping of the contact model.
            standard_gravity: The standard gravity acceleration.
            static_friction_coefficient: The static friction coefficient.
            max_penetration: The maximum penetration depth.
            number_of_active_collidable_points_steady_state:
                The number of active collidable points in steady state.
            damping_ratio: The damping ratio.
            p: The first parameter of the contact model.
            q: The second parameter of the contact model.
            **kwargs: Optional additional arguments, forwarded to `build`.

        Returns:
            The `ContactsParams` instance.

        Note:
            The `stiffness` is intended as the terrain stiffness in the Soft Contacts model,
            while it is the Baumgarte stabilization stiffness in the Rigid Contacts model.

            The `damping` is intended as the terrain damping in the Soft Contacts model,
            while it is the Baumgarte stabilization damping in the Rigid Contacts model.

            The `damping_ratio` parameter allows to operate on the following conditions:
            - ξ > 1.0: over-damped
            - ξ = 1.0: critically damped
            - ξ < 1.0: under-damped

            Only the automatically computed stiffness and damping are clipped
            to `MAX_STIFFNESS` and `MAX_DAMPING`.
        """

        # Use symbols for input parameters.
        ξ = damping_ratio
        δ_max = max_penetration
        μc = static_friction_coefficient
        nc = number_of_active_collidable_points_steady_state

        # Compute the total mass of the model.
        m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum()

        # Compute the stiffness to get the desired steady-state penetration.
        # Note that this is dependent on the non-linear exponent used in
        # the damping term of the Hunt/Crossley model.
        if stiffness is None:
            # Compute the average support force on each collidable point.
            f_average = m * standard_gravity / nc

            stiffness = f_average / jnp.power(δ_max, 1 + p)
            stiffness = jnp.clip(stiffness, 0, MAX_STIFFNESS)

        # Compute the damping using the damping ratio.
        critical_damping = 2 * jnp.sqrt(stiffness * m)
        if damping is None:
            damping = ξ * critical_damping
            damping = jnp.clip(damping, 0, MAX_DAMPING)

        return self.build(
            K=stiffness,
            D=damping,
            mu=μc,
            p=p,
            q=q,
            **kwargs,
        )

    @abc.abstractmethod
    def valid(self, **kwargs) -> jtp.BoolLike:
        """
        Check if the parameters are valid.

        Returns:
            True if the parameters are valid, False otherwise.
        """
        pass


class ContactModel(JaxsimDataclass):
    """
    Abstract class representing a contact model.
    """

    @classmethod
    @abc.abstractmethod
    def build(
        cls: type[Self],
        **kwargs,
    ) -> Self:
        """
        Create a `ContactModel` instance with specified parameters.

        Returns:
            The `ContactModel` instance.
        """

        pass

    @abc.abstractmethod
    def compute_contact_forces(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        **kwargs,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the contact forces.

        Args:
            model: The robot model considered by the contact model.
            data: The data of the considered model.
            **kwargs: Optional additional arguments, specific to the contact model.

        Returns:
            A tuple containing as first element the computed 6D contact force applied to
            the contact points and expressed in the world frame, and as second element
            a dictionary of optional additional information.
        """

        pass

    @classmethod
    def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:
        """
        Build zero state variables of the contact model.

        Args:
            model: The robot model considered by the contact model.

        Note:
            There are contact models that require to extend the state vector of the
            integrated ODE system with additional variables. Our integrators are
            capable of operating on a generic state, as long as it is a PyTree.
            This method builds the zero state variables of the contact model as a
            dictionary of JAX arrays.

        Returns:
            A dictionary storing the zero state variables of the contact model.
            The base implementation returns an empty dictionary.
        """

        return {}

    @property
    def _parameters_class(self) -> type[ContactsParams]:
        """
        Return the class of the contact parameters.

        The class is looked up by name in ``jaxsim.rbda.contacts`` as
        ``<ContactModelClassName> + "Params"``.

        Returns:
            The class of the contact parameters.
        """

        # Imported lazily, presumably to avoid a circular import with the
        # package that imports this module.
        import importlib

        return getattr(
            importlib.import_module("jaxsim.rbda.contacts"),
            (
                self.__name__ + "Params"
                if isinstance(self, type)
                else self.__class__.__name__ + "Params"
            ),
        )

    @abc.abstractmethod
    def update_contact_state(
        self: type[Self], old_contact_state: dict[str, jtp.Array]
    ) -> dict[str, jtp.Array]:
        """
        Update the contact state.

        Args:
            old_contact_state: The old contact state.

        Returns:
            The updated contact state.
        """

    @abc.abstractmethod
    def update_velocity_after_impact(
        self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData
    ) -> js.data.JaxSimModelData:
        """
        Update the velocity after an impact.

        Args:
            model: The robot model considered by the contact model.
            data: The data of the considered model.

        Returns:
            The updated data of the considered model.
        """

================================================
FILE: src/jaxsim/rbda/contacts/relaxed_rigid.py
================================================
from __future__ import annotations

import dataclasses
from collections.abc import Callable
from typing import Any

import jax
import jax.numpy as jnp
import jax_dataclasses
import optax
from optax.tree_utils import tree_get

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr

from .
import common, soft try: from typing import Self except ImportError: from typing_extensions import Self try: from optax.tree_utils import tree_norm except ImportError: from optax.tree_utils import tree_l2_norm as tree_norm @jax_dataclasses.pytree_dataclass class RelaxedRigidContactsParams(common.ContactsParams): """Parameters of the relaxed rigid contacts model.""" # Time constant time_constant: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.02, dtype=float) ) # Adimensional damping coefficient damping_coefficient: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(1.0, dtype=float) ) # Minimum impedance d_min: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.9, dtype=float) ) # Maximum impedance d_max: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.95, dtype=float) ) # Width width: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.001, dtype=float) ) # Midpoint midpoint: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.5, dtype=float) ) # Power exponent power: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(2.0, dtype=float) ) # Stiffness K: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.0, dtype=float) ) # Damping D: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.0, dtype=float) ) # Friction coefficient mu: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.005, dtype=float) ) def __hash__(self) -> int: from jaxsim.utils.wrappers import HashedNumpyArray return hash( ( HashedNumpyArray(self.time_constant), HashedNumpyArray(self.damping_coefficient), HashedNumpyArray(self.d_min), HashedNumpyArray(self.d_max), HashedNumpyArray(self.width), HashedNumpyArray(self.midpoint), HashedNumpyArray(self.power), HashedNumpyArray(self.K), HashedNumpyArray(self.D), HashedNumpyArray(self.mu), ) ) def __eq__(self, other: RelaxedRigidContactsParams) -> bool: if not isinstance(other, 
RelaxedRigidContactsParams): return False return hash(self) == hash(other) @classmethod def build( cls: type[Self], *, time_constant: jtp.FloatLike | None = None, damping_coefficient: jtp.FloatLike | None = None, d_min: jtp.FloatLike | None = None, d_max: jtp.FloatLike | None = None, width: jtp.FloatLike | None = None, midpoint: jtp.FloatLike | None = None, power: jtp.FloatLike | None = None, K: jtp.FloatLike | None = None, D: jtp.FloatLike | None = None, mu: jtp.FloatLike | None = None, **kwargs, ) -> Self: """Create a `RelaxedRigidContactsParams` instance.""" def default(name: str): return cls.__dataclass_fields__[name].default_factory() return cls( time_constant=jnp.array( ( time_constant if time_constant is not None else default("time_constant") ), dtype=float, ), damping_coefficient=jnp.array( ( damping_coefficient if damping_coefficient is not None else default("damping_coefficient") ), dtype=float, ), d_min=jnp.array( d_min if d_min is not None else default("d_min"), dtype=float ), d_max=jnp.array( d_max if d_max is not None else default("d_max"), dtype=float ), width=jnp.array( width if width is not None else default("width"), dtype=float ), midpoint=jnp.array( midpoint if midpoint is not None else default("midpoint"), dtype=float ), power=jnp.array( power if power is not None else default("power"), dtype=float ), K=jnp.array( K if K is not None else default("K"), dtype=float, ), D=jnp.array(D if D is not None else default("D"), dtype=float), mu=jnp.array(mu if mu is not None else default("mu"), dtype=float), ) def valid(self) -> jtp.BoolLike: """Check if the parameters are valid.""" return bool( jnp.all(self.time_constant >= 0.0) and jnp.all(self.damping_coefficient > 0.0) and jnp.all(self.d_min >= 0.0) and jnp.all(self.d_max <= 1.0) and jnp.all(self.d_min <= self.d_max) and jnp.all(self.width >= 0.0) and jnp.all(self.midpoint >= 0.0) and jnp.all(self.power >= 0.0) and jnp.all(self.mu >= 0.0) ) @jax_dataclasses.pytree_dataclass class 
RelaxedRigidContacts(common.ContactModel): """Relaxed rigid contacts model.""" _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field( default=("tol", "maxiter", "memory_size", "scale_init_precond"), kw_only=True ) _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field( default=(1e-6, 50, 10, False), kw_only=True ) @property def solver_options(self) -> dict[str, Any]: """Get the solver options.""" return dict( zip( self._solver_options_keys, self._solver_options_values, strict=True, ) ) @classmethod def build( cls: type[Self], solver_options: dict[str, Any] | None = None, **kwargs, ) -> Self: """ Create a `RelaxedRigidContacts` instance with specified parameters. Args: solver_options: The options to pass to the L-BFGS solver. **kwargs: The parameters of the relaxed rigid contacts model. Returns: The `RelaxedRigidContacts` instance. """ # Get the default solver options. default_solver_options = dict( zip(cls._solver_options_keys, cls._solver_options_values, strict=True) ) # Create the solver options to set by combining the default solver options # with the user-provided solver options. solver_options = default_solver_options | ( solver_options if solver_options is not None else {} ) # Make sure that the solver options are hashable. # We need to check this because the solver options are static. try: hash(tuple(solver_options.values())) except TypeError as exc: raise ValueError( "The values of the solver options must be hashable." ) from exc return cls( _solver_options_keys=tuple(solver_options.keys()), _solver_options_values=tuple(solver_options.values()), **kwargs, ) def update_contact_state( self: type[Self], old_contact_state: dict[str, jtp.Array] ) -> dict[str, jtp.Array]: """ Update the contact state. Args: old_contact_state: The old contact state. Returns: The updated contact state. 
""" return {} def update_velocity_after_impact( self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> js.data.JaxSimModelData: """ Update the velocity after an impact. Args: model: The robot model considered by the contact model. data: The data of the considered model. Returns: The updated data of the considered model. """ return data @jax.jit def compute_contact_forces( self, model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: """ Compute the contact forces. Args: model: The model to consider. data: The data of the considered model. link_forces: Optional `(n_links, 6)` matrix of external forces acting on the links, expressed in the same representation of data. joint_force_references: Optional `(n_joints,)` vector of joint forces. Returns: A tuple containing as first element the computed contact forces in inertial representation. """ link_forces = jnp.atleast_2d( jnp.array(link_forces, dtype=float).squeeze() if link_forces is not None else jnp.zeros((model.number_of_links(), 6)) ) joint_force_references = jnp.atleast_1d( jnp.array(joint_force_references, dtype=float).squeeze() if joint_force_references is not None else jnp.zeros(model.number_of_joints()) ) references = js.references.JaxSimModelReferences.build( model=model, data=data, velocity_representation=data.velocity_representation, link_forces=link_forces, joint_force_references=joint_force_references, ) # Compute the position and linear velocities (mixed representation) of # all collidable points belonging to the robot. position, velocity = js.contact.collidable_point_kinematics( model=model, data=data ) # Compute the penetration depth and velocity of the collidable points. # Note that this function considers the penetration in the normal direction. 
δ, _, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( position, velocity, model.terrain ) # Compute the position in the constraint frame. position_constraint = jax.vmap(lambda δ, n̂: -δ * n̂)(δ, n̂) # Compute the regularization terms. a_ref, r, *_ = self._regularizers( model=model, position_constraint=position_constraint, velocity_constraint=velocity, parameters=model.contact_params, ) # Compute the transforms of the implicit frames corresponding to the # collidable points. W_H_C = js.contact.transforms(model=model, data=data) with ( data.switch_velocity_representation(VelRepr.Mixed), references.switch_velocity_representation(VelRepr.Mixed), ): BW_ν = data.generalized_velocity BW_ν̇_free = jnp.hstack( js.model.forward_dynamics_aba( model=model, data=data, link_forces=references.link_forces(model=model, data=data), joint_forces=references.joint_force_references(model=model), ) ) M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data) # Compute the linear part of the Jacobian of the collidable points Jl_WC = jnp.vstack( jax.vmap(lambda J, δ: J * (δ > 0))( js.contact.jacobian(model=model, data=data)[:, :3, :], δ ) ) # Compute the linear part of the Jacobian derivative of the collidable points J̇l_WC = jnp.vstack( jax.vmap(lambda J̇, δ: J̇ * (δ > 0))( js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ ), ) # Compute the Delassus matrix directly using J and J̇. G_contacts = Jl_WC @ M_inv @ Jl_WC.T # Compute the free mixed linear acceleration of the collidable points. CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇l_WC @ BW_ν # Calculate quantities for the linear optimization problem. R = jnp.diag(r) A = G_contacts + R b = CW_al_free_WC - a_ref # Create the objective function to minimize as a lambda computing the cost # from the optimized variables x. 
objective = lambda x, A, b: jnp.sum(jnp.square(A @ x + b)) # ======================================== # Helper function to run the L-BFGS solver # ======================================== def run_optimization( init_params: jtp.Vector, fun: Callable, opt: optax.GradientTransformationExtraArgs, maxiter: int, tol: float, ) -> tuple[jtp.Vector, optax.OptState]: # Get the function to compute the loss and the gradient w.r.t. its inputs. value_and_grad_fn = optax.value_and_grad_from_state(fun) # Initialize the carry of the following loop. OptimizationCarry = tuple[jtp.Vector, optax.OptState] init_carry: OptimizationCarry = (init_params, opt.init(params=init_params)) def step(carry: OptimizationCarry) -> OptimizationCarry: params, state = carry value, grad = value_and_grad_fn( params, state=state, A=A, b=b, ) updates, state = opt.update( updates=grad, state=state, params=params, value=value, grad=grad, value_fn=fun, A=A, b=b, ) params = optax.apply_updates(params, updates) return params, state # TODO: maybe fix the number of iterations and switch to scan? def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: _, state = carry iter_num = tree_get(state, "count") grad = tree_get(state, "grad") err = tree_norm(grad) return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol)) final_params, final_state = jax.lax.while_loop( continuing_criterion, step, init_carry ) return final_params, final_state # ====================================== # Compute the contact forces with L-BFGS # ====================================== # Initialize the optimized forces with a linear Hunt/Crossley model. init_params = jax.vmap( lambda p, v: soft.SoftContacts.hunt_crossley_contact_model( position=p, velocity=v, terrain=model.terrain, K=1e6, D=2e3, p=0.5, q=0.5, # No tangential initial forces. mu=0.0, tangential_deformation=jnp.zeros(3), )[0] )(position, velocity).flatten() # Get the solver options. 
solver_options = self.solver_options # Extract the options corresponding to the convergence criteria. # All the remaining options are passed to the solver. tol = solver_options.pop("tol") maxiter = solver_options.pop("maxiter") solve_fn = lambda *_: run_optimization( init_params=init_params, fun=objective, opt=optax.lbfgs(**solver_options), tol=tol, maxiter=maxiter, ) # Compute the 3D linear force in C[W] frame. solution, _ = jax.lax.custom_linear_solve( lambda x: A @ x, -b, solve=solve_fn, symmetric=True, has_aux=True, ) # Reshape the optimized solution to be a matrix of 3D contact forces. CW_fl_C = solution.reshape(-1, 3) # Convert the contact forces from mixed to inertial-fixed representation. W_f_C = jax.vmap( lambda CW_fl_C, W_H_C: ( ModelDataWithVelocityRepresentation.other_representation_to_inertial( array=jnp.zeros(6).at[0:3].set(CW_fl_C), transform=W_H_C, other_representation=VelRepr.Mixed, is_force=True, ) ), )(CW_fl_C, W_H_C) return W_f_C, {} @staticmethod def _regularizers( model: js.model.JaxSimModel, position_constraint: jtp.Vector, velocity_constraint: jtp.Vector, parameters: RelaxedRigidContactsParams, ) -> tuple: """ Compute the contact jacobian and the reference acceleration. Args: model: The jaxsim model. position_constraint: The position of the collidable points in the constraint frame. velocity_constraint: The velocity of the collidable points in the constraint frame. parameters: The parameters of the relaxed rigid contacts model. Returns: A tuple containing the reference acceleration, the regularization matrix, the stiffness, and the damping. """ # Extract the parameters of the contact model. Ω, ζ, ξ_min, ξ_max, width, mid, p, K, D, μ = ( getattr(parameters, field) for field in ( "time_constant", "damping_coefficient", "d_min", "d_max", "width", "midpoint", "power", "K", "D", "mu", ) ) # Get the indices of the enabled collidable points. 
indices_of_enabled_collidable_points = ( model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points ) parent_link_idx_of_enabled_collidable_points = jnp.array( model.kin_dyn_parameters.contact_parameters.body, dtype=int )[indices_of_enabled_collidable_points] # Compute the 6D inertia matrices of all links. M_L = js.model.link_spatial_inertia_matrices(model=model) def imp_aref( pos: jtp.Vector, vel: jtp.Vector, ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector]: """ Calculate impedance and offset acceleration in constraint frame. Args: pos: position in constraint frame. vel: velocity in constraint frame. Returns: ξ: computed impedance a_ref: offset acceleration in constraint frame K: computed stiffness D: computed damping """ imp_x = jnp.abs(pos) / width imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p) imp_b = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - imp_x, p) imp_y = jnp.where(imp_x < mid, imp_a, imp_b) # Compute the impedance. ξ = ξ_min + imp_y * (ξ_max - ξ_min) ξ = jnp.clip(ξ, ξ_min, ξ_max) ξ = jnp.where(imp_x > 1.0, ξ_max, ξ) # Compute the spring and damper parameters during runtime from the # impedance and other contact parameters. K = 1 / (ξ_max * Ω * ζ) ** 2 D = 2 / (ξ_max * Ω) # If the user specifies K and D and they are negative, the computed `a_ref` # becomes something more similar to a classic Baumgarte regularization. K = jnp.where(K < 0, -K / ξ_max**2, K) D = jnp.where(D < 0, -D / ξ_max, D) # Compute the reference acceleration. a_ref = -(D * vel + K * ξ * pos) return ξ, a_ref, K, D def compute_row( *, link_idx: jtp.Int, pos: jtp.Vector, vel: jtp.Vector, ) -> tuple[jtp.Vector, jtp.Matrix, jtp.Vector, jtp.Vector]: # Compute the reference acceleration. ξ, a_ref, K, D = imp_aref(pos=pos, vel=vel) # Compute the regularization term. R = ( (2 * μ**2 * (1 - ξ) / (ξ + 1e-12)) * (1 + μ**2) @ jnp.linalg.inv(M_L[link_idx, :3, :3]) ) # Return the computed values, setting them to zero in case of no contact. 
is_active = (pos.dot(pos) > 0).astype(float) return jax.tree.map( lambda x: jnp.atleast_1d(x) * is_active, (a_ref, R, K, D) ) a_ref, R, K, D = jax.tree.map( f=jnp.concatenate, tree=( *jax.vmap(compute_row)( link_idx=parent_link_idx_of_enabled_collidable_points, pos=position_constraint, vel=velocity_constraint, ), ), ) return a_ref, R, K, D ================================================ FILE: src/jaxsim/rbda/contacts/rigid.py ================================================ from __future__ import annotations import dataclasses from typing import Any import jax import jax.numpy as jnp import jax_dataclasses import qpax import jaxsim.api as js import jaxsim.typing as jtp from jaxsim import logging from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr from . import common from .common import ContactModel, ContactsParams try: from typing import Self except ImportError: from typing_extensions import Self @jax_dataclasses.pytree_dataclass class RigidContactsParams(ContactsParams): """Parameters of the rigid contacts model.""" # Static friction coefficient mu: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.5, dtype=float) ) # Baumgarte proportional term K: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.0, dtype=float) ) # Baumgarte derivative term D: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.0, dtype=float) ) def __hash__(self) -> int: from jaxsim.utils.wrappers import HashedNumpyArray return hash( ( HashedNumpyArray.hash_of_array(self.mu), HashedNumpyArray.hash_of_array(self.K), HashedNumpyArray.hash_of_array(self.D), ) ) def __eq__(self, other: RigidContactsParams) -> bool: if not isinstance(other, RigidContactsParams): return False return hash(self) == hash(other) @classmethod def build( cls: type[Self], *, mu: jtp.FloatLike | None = None, K: jtp.FloatLike | None = None, D: jtp.FloatLike | None = None, **kwargs, ) -> Self: """Create a `RigidContactParams` instance.""" return cls( 
mu=jnp.array( mu if mu is not None else cls.__dataclass_fields__["mu"].default_factory() ).astype(float), K=jnp.array( K if K is not None else cls.__dataclass_fields__["K"].default_factory() ).astype(float), D=jnp.array( D if D is not None else cls.__dataclass_fields__["D"].default_factory() ).astype(float), ) def valid(self) -> jtp.BoolLike: """Check if the parameters are valid.""" return bool( jnp.all(self.mu >= 0.0) and jnp.all(self.K >= 0.0) and jnp.all(self.D >= 0.0) ) @jax_dataclasses.pytree_dataclass class RigidContacts(ContactModel): """Rigid contacts model.""" regularization_delassus: jax_dataclasses.Static[float] = dataclasses.field( default=1e-6, kw_only=True ) _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field( default=("solver_tol",), kw_only=True ) _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field( default=(1e-3,), kw_only=True ) @property def solver_options(self) -> dict[str, Any]: """Get the solver options as a dictionary.""" return dict( zip( self._solver_options_keys, self._solver_options_values, strict=True, ) ) @classmethod def build( cls: type[Self], regularization_delassus: jtp.FloatLike | None = None, solver_options: dict[str, Any] | None = None, **kwargs, ) -> Self: """ Create a `RigidContacts` instance with specified parameters. Args: regularization_delassus: The regularization term to add to the diagonal of the Delassus matrix. solver_options: The options to pass to the QP solver. **kwargs: Extra arguments which are ignored. Returns: The `RigidContacts` instance. """ if kwargs: logging.warning(msg=f"Ignoring extra arguments: {kwargs}") # Get the default solver options. default_solver_options = dict( zip(cls._solver_options_keys, cls._solver_options_values, strict=True) ) # Create the solver options to set by combining the default solver options # with the user-provided solver options. 
solver_options = default_solver_options | ( solver_options if solver_options is not None else {} ) # Make sure that the solver options are hashable. # We need to check this because the solver options are static. try: hash(tuple(solver_options.values())) except TypeError as exc: raise ValueError( "The values of the solver options must be hashable." ) from exc return cls( regularization_delassus=float( regularization_delassus if regularization_delassus is not None else cls.__dataclass_fields__["regularization_delassus"].default ), _solver_options_keys=tuple(solver_options.keys()), _solver_options_values=tuple(solver_options.values()), **kwargs, ) @staticmethod def compute_impact_velocity( inactive_collidable_points: jtp.ArrayLike, M: jtp.MatrixLike, J_WC: jtp.MatrixLike, generalized_velocity: jtp.VectorLike, ) -> jtp.Vector: """ Return the new velocity of the system after a potential impact. Args: inactive_collidable_points: The activation state of the collidable points. M: The mass matrix of the system (in mixed representation). J_WC: The Jacobian matrix of the collidable points (in mixed representation). generalized_velocity: The generalized velocity of the system. Note: The mass matrix `M`, the Jacobian `J_WC`, and the generalized velocity `generalized_velocity` must be expressed in the same velocity representation. """ # Compute system velocity after impact maintaining zero linear velocity of active points. sl = jnp.s_[:, 0:3, :] Jl_WC = J_WC[sl] # Zero out the jacobian rows of inactive points. 
Jl_WC = jnp.vstack( jnp.where( inactive_collidable_points[:, jnp.newaxis, jnp.newaxis], jnp.zeros_like(Jl_WC), Jl_WC, ) ) A = jnp.vstack( [ jnp.hstack([M, -Jl_WC.T]), jnp.hstack([Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]), ] ) b = jnp.hstack([M @ generalized_velocity, jnp.zeros(Jl_WC.shape[0])]) BW_ν_post_impact = jnp.linalg.lstsq(A, b)[0] return BW_ν_post_impact[0 : M.shape[0]] @jax.jit @js.common.named_scope def compute_contact_forces( self, model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: """ Compute the contact forces. Args: model: The model to consider. data: The data of the considered model. link_forces: Optional `(n_links, 6)` matrix of external forces acting on the links, expressed in the same representation of data. joint_force_references: Optional `(n_joints,)` vector of joint forces. Returns: A tuple containing as first element the computed contact forces. """ # Get the indices of the enabled collidable points. indices_of_enabled_collidable_points = ( model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points ) n_collidable_points = len(indices_of_enabled_collidable_points) link_forces = jnp.atleast_2d( jnp.array(link_forces, dtype=float).squeeze() if link_forces is not None else jnp.zeros((model.number_of_links(), 6)) ) joint_force_references = jnp.atleast_1d( jnp.array(joint_force_references, dtype=float).squeeze() if joint_force_references is not None else jnp.zeros((model.number_of_joints(),)) ) # Build a references object to simplify converting link forces. 
references = js.references.JaxSimModelReferences.build( model=model, data=data, velocity_representation=data.velocity_representation, link_forces=link_forces, joint_force_references=joint_force_references, ) # Compute the position and linear velocities (mixed representation) of # all enabled collidable points belonging to the robot. position, velocity = js.contact.collidable_point_kinematics( model=model, data=data ) # Compute the penetration depth and velocity of the collidable points. # Note that this function considers the penetration in the normal direction. δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( position, velocity, model.terrain ) W_H_C = js.contact.transforms(model=model, data=data) with ( references.switch_velocity_representation(VelRepr.Mixed), data.switch_velocity_representation(VelRepr.Mixed), ): # Compute kin-dyn quantities used in the contact model. BW_ν = data.generalized_velocity M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data) J_WC = js.contact.jacobian(model=model, data=data) J̇_WC = js.contact.jacobian_derivative(model=model, data=data) # Compute the generalized free acceleration. BW_ν̇_free = jnp.hstack( js.model.forward_dynamics_aba( model=model, data=data, link_forces=references.link_forces(model=model, data=data), joint_forces=references.joint_force_references(model=model), ) ) # Compute the free linear acceleration of the collidable points. # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C. free_contact_acc = _linear_acceleration_of_collidable_points( BW_nu=BW_ν, BW_nu_dot=BW_ν̇_free, CW_J_WC_BW=J_WC, CW_J_dot_WC_BW=J̇_WC, ).flatten() # Compute stabilization term. baumgarte_term = _compute_baumgarte_stabilization_term( inactive_collidable_points=(δ <= 0), δ=δ, δ_dot=δ_dot, n=n̂, K=model.contact_params.K, D=model.contact_params.D, ).flatten() # Compute the Delassus matrix. 
delassus_matrix = _delassus_matrix(M_inv=M_inv, J_WC=J_WC) # Initialize regularization term of the Delassus matrix for # better numerical conditioning. Iε = self.regularization_delassus * jnp.eye(delassus_matrix.shape[0]) # Construct the quadratic cost function. Q = delassus_matrix + Iε q = free_contact_acc - baumgarte_term # Construct the inequality constraints. G = _compute_ineq_constraint_matrix( inactive_collidable_points=(δ <= 0), mu=model.contact_params.mu ) h_bounds = jnp.zeros(shape=(n_collidable_points * 6,)) # Construct the equality constraints. A = jnp.zeros((0, 3 * n_collidable_points)) b = jnp.zeros((0,)) # Solve the following optimization problem with qpax: # # min_{x} 0.5 x⊤ Q x + q⊤ x # # s.t. A x = b # G x ≤ h # # TODO: add possibility to notify if the QP problem did not converge. solution, _, _, _, converged, _ = qpax.solve_qp( # noqa: RUF059 Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options ) # Reshape the optimized solution to be a matrix of 3D contact forces. CW_fl_C = solution.reshape(-1, 3) # Convert the contact forces from mixed to inertial-fixed representation. W_f_C = jax.vmap( lambda CW_fl_C, W_H_C: ( ModelDataWithVelocityRepresentation.other_representation_to_inertial( array=jnp.zeros(6).at[0:3].set(CW_fl_C), transform=W_H_C, other_representation=VelRepr.Mixed, is_force=True, ) ), )(CW_fl_C, W_H_C) return W_f_C, {} @jax.jit @js.common.named_scope def update_velocity_after_impact( self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> js.data.JaxSimModelData: """ Update the velocity after an impact. Args: model: The robot model considered by the contact model. data: The data of the considered model. Returns: The updated data of the considered model. """ # Extract the indices corresponding to the enabled collidable points. 
indices_of_enabled_collidable_points = ( model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points ) W_p_C = js.contact.collidable_point_positions(model, data)[ indices_of_enabled_collidable_points ] # Compute the penetration depth of the collidable points. δ, *_ = jax.vmap( common.compute_penetration_data, in_axes=(0, 0, None), )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) with data.switch_velocity_representation(VelRepr.Mixed): J_WC = js.contact.jacobian(model, data)[ indices_of_enabled_collidable_points ] M = js.model.free_floating_mass_matrix(model, data) BW_ν_pre_impact = data.generalized_velocity # Compute the impact velocity. # It may be discontinuous in case new contacts are made. BW_ν_post_impact = RigidContacts.compute_impact_velocity( generalized_velocity=BW_ν_pre_impact, inactive_collidable_points=(δ <= 0), M=M, J_WC=J_WC, ) BW_ν_post_impact_inertial = data.other_representation_to_inertial( array=BW_ν_post_impact[0:6], other_representation=VelRepr.Mixed, transform=data._base_transform.at[0:3, 0:3].set(jnp.eye(3)), is_force=False, ) # Reset the generalized velocity. data = dataclasses.replace( data, _base_linear_velocity=BW_ν_post_impact_inertial[0:3], _base_angular_velocity=BW_ν_post_impact_inertial[3:6], _joint_velocities=BW_ν_post_impact[6:], ) return data def update_contact_state( self: type[Self], old_contact_state: dict[str, jtp.Array] ) -> dict[str, jtp.Array]: """ Update the contact state. Args: old_contact_state: The old contact state. Returns: The updated contact state. """ return {} @staticmethod def _delassus_matrix( M_inv: jtp.MatrixLike, J_WC: jtp.MatrixLike, ) -> jtp.Matrix: sl = jnp.s_[:, 0:3, :] J_WC_lin = jnp.vstack(J_WC[sl]) delassus_matrix = J_WC_lin @ M_inv @ J_WC_lin.T return delassus_matrix @jax.jit @js.common.named_scope def _compute_ineq_constraint_matrix( inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike ) -> jtp.Matrix: """ Compute the inequality constraint matrix for a single collidable point. 
Rows 0-3: enforce the friction pyramid constraint, Row 4: last one is for the non negativity of the vertical force Row 5: contact complementarity condition """ G_single_point = jnp.array( [ [1, 0, -mu], [0, 1, -mu], [-1, 0, -mu], [0, -1, -mu], [0, 0, -1], [0, 0, 0], ] ) G = jnp.tile(G_single_point, (len(inactive_collidable_points), 1, 1)) G = G.at[:, 5, 2].set(inactive_collidable_points) G = jax.scipy.linalg.block_diag(*G) return G @jax.jit @js.common.named_scope def _linear_acceleration_of_collidable_points( BW_nu: jtp.ArrayLike, BW_nu_dot: jtp.ArrayLike, CW_J_WC_BW: jtp.MatrixLike, CW_J_dot_WC_BW: jtp.MatrixLike, ) -> jtp.Matrix: BW_ν = BW_nu BW_ν̇ = BW_nu_dot CW_J̇_WC_BW = CW_J_dot_WC_BW # Compute the linear acceleration of the collidable points. # Since we use doubly-mixed jacobians, this corresponds to W_p̈_C. CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇ CW_a_WC = CW_a_WC.reshape(-1, 6) return CW_a_WC[:, 0:3].squeeze() @jax.jit @js.common.named_scope def _compute_baumgarte_stabilization_term( inactive_collidable_points: jtp.ArrayLike, δ: jtp.ArrayLike, δ_dot: jtp.ArrayLike, n: jtp.ArrayLike, K: jtp.FloatLike, D: jtp.FloatLike, ) -> jtp.Array: return jnp.where( inactive_collidable_points[:, jnp.newaxis], jnp.zeros_like(n), (K * δ + D * δ_dot)[:, jnp.newaxis] * n, ) ================================================ FILE: src/jaxsim/rbda/contacts/soft.py ================================================ from __future__ import annotations import dataclasses import functools import jax import jax.numpy as jnp import jax_dataclasses import jaxsim.api as js import jaxsim.math import jaxsim.typing as jtp from jaxsim import logging from jaxsim.terrain import Terrain from . 
import common try: from typing import Self except ImportError: from typing_extensions import Self @jax_dataclasses.pytree_dataclass class SoftContactsParams(common.ContactsParams): """Parameters of the soft contacts model.""" K: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(1e6, dtype=float) ) D: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(2000, dtype=float) ) mu: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.5, dtype=float) ) p: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.5, dtype=float) ) q: jtp.Float = dataclasses.field( default_factory=lambda: jnp.array(0.5, dtype=float) ) def __hash__(self) -> int: from jaxsim.utils.wrappers import HashedNumpyArray return hash( ( HashedNumpyArray.hash_of_array(self.K), HashedNumpyArray.hash_of_array(self.D), HashedNumpyArray.hash_of_array(self.mu), HashedNumpyArray.hash_of_array(self.p), HashedNumpyArray.hash_of_array(self.q), ) ) def __eq__(self, other: SoftContactsParams) -> bool: if not isinstance(other, SoftContactsParams): return False return hash(self) == hash(other) @classmethod def build( cls: type[Self], *, K: jtp.FloatLike = 1e6, D: jtp.FloatLike = 2_000, mu: jtp.FloatLike = 0.5, p: jtp.FloatLike = 0.5, q: jtp.FloatLike = 0.5, **kwargs, ) -> Self: """ Create a SoftContactsParams instance with specified parameters. Args: K: The stiffness parameter. D: The damping parameter of the soft contacts model. mu: The static friction coefficient. p: The exponent p corresponding to the damping-related non-linearity of the Hunt/Crossley model. q: The exponent q corresponding to the spring-related non-linearity of the Hunt/Crossley model **kwargs: Additional parameters to pass to the contact model. Returns: A SoftContactsParams instance with the specified parameters. 
""" return SoftContactsParams( K=jnp.array(K, dtype=float), D=jnp.array(D, dtype=float), mu=jnp.array(mu, dtype=float), p=jnp.array(p, dtype=float), q=jnp.array(q, dtype=float), ) def valid(self) -> jtp.BoolLike: """ Check if the parameters are valid. Returns: `True` if the parameters are valid, `False` otherwise. """ return jnp.hstack( [ self.K >= 0.0, self.D >= 0.0, self.mu >= 0.0, self.p >= 0.0, self.q >= 0.0, ] ).all() @jax_dataclasses.pytree_dataclass class SoftContacts(common.ContactModel): """Soft contacts model.""" @classmethod def build( cls: type[Self], **kwargs, ) -> Self: """ Create a `SoftContacts` instance with specified parameters. Args: **kwargs: Additional parameters to pass to the contact model. Returns: The `SoftContacts` instance. """ if kwargs: logging.warning(msg=f"Ignoring extra arguments: {kwargs}") return cls(**kwargs) @classmethod def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]: """ Build zero state variables of the contact model. """ # Initialize the material deformation to zero. tangential_deformation = jnp.zeros( shape=(len(model.kin_dyn_parameters.contact_parameters.body), 3), dtype=float, ) return {"tangential_deformation": tangential_deformation} def update_contact_state( self: type[Self], old_contact_state: dict[str, jtp.Array] ) -> dict[str, jtp.Array]: """ Update the contact state. Args: old_contact_state: The old contact state. Returns: The updated contact state. """ return {"tangential_deformation": old_contact_state["m_dot"]} def update_velocity_after_impact( self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> js.data.JaxSimModelData: """ Update the velocity after an impact. Args: model: The robot model considered by the contact model. data: The data of the considered model. Returns: The updated data of the considered model. 
""" return data @staticmethod @functools.partial(jax.jit, static_argnames=("terrain",)) def hunt_crossley_contact_model( position: jtp.VectorLike, velocity: jtp.VectorLike, tangential_deformation: jtp.VectorLike, terrain: Terrain, K: jtp.FloatLike, D: jtp.FloatLike, mu: jtp.FloatLike, p: jtp.FloatLike = 0.5, q: jtp.FloatLike = 0.5, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute the contact force using the Hunt/Crossley model. Args: position: The position of the collidable point. velocity: The velocity of the collidable point. tangential_deformation: The material deformation of the collidable point. terrain: The terrain model. K: The stiffness parameter. D: The damping parameter of the soft contacts model. mu: The static friction coefficient. p: The exponent p corresponding to the damping-related non-linearity of the Hunt/Crossley model. q: The exponent q corresponding to the spring-related non-linearity of the Hunt/Crossley model Returns: A tuple containing the computed contact force and the derivative of the material deformation. """ # Convert the input vectors to arrays. W_p_C = jnp.array(position, dtype=float).squeeze() W_ṗ_C = jnp.array(velocity, dtype=float).squeeze() m = jnp.array(tangential_deformation, dtype=float).squeeze() # Use symbol for the static friction. μ = mu # Compute the penetration depth, its rate, and the considered terrain normal. δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain) # There are few operations like computing the norm of a vector with zero length # or computing the square root of zero that are problematic in an AD context. # To avoid these issues, we introduce a small tolerance ε to their arguments # and make sure that we do not check them against zero directly. ε = jnp.finfo(float).eps # Compute the powers of the penetration depth. # Inject ε to address AD issues in differentiating the square root when # p and q are fractional. 
δp = jnp.power(δ + ε, p) δq = jnp.power(δ + ε, q) # ======================== # Compute the normal force # ======================== # Non-linear spring-damper model (Hunt/Crossley model). # This is the force magnitude along the direction normal to the terrain. force_normal_mag = (K * δp) * δ + (D * δq) * δ̇ # Depending on the magnitude of δ̇, the normal force could be negative. force_normal_mag = jnp.maximum(0.0, force_normal_mag) # Compute the 3D linear force in C[W] frame. f_normal = force_normal_mag * n̂ # ============================ # Compute the tangential force # ============================ # Extract the tangential component of the velocity. v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, n̂) * n̂ # Extract the normal and tangential components of the material deformation. m_normal = jnp.dot(m, n̂) * n̂ m_tangential = m - jnp.dot(m, n̂) * n̂ # Compute the tangential force in the sticking case. # Using the tangential component of the material deformation should not be # necessary if the sticking-slipping transition occurs in a terrain area # with a locally constant normal. However, this assumption is not true in # general, especially for highly uneven terrains. f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential) # Detect the contact type (sticking or slipping). # Note that if there is no contact, sticking is set to True, and this detail # is exploited in the computation of the `contact_status` variable. sticking = jnp.logical_or( δ <= 0, f_tangential.dot(f_tangential) <= (μ * force_normal_mag) ** 2 ) # Compute the direction of the tangential force. # To prevent dividing by zero, we use a switch statement. norm = jaxsim.math.safe_norm(f_tangential) f_tangential_direction = f_tangential / ( norm + jnp.finfo(float).eps * (norm == 0) ) # Project the tangential force to the friction cone if slipping. 
f_tangential = jnp.where( sticking, f_tangential, jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction, ) # Set the tangential force to zero if there is no contact. f_tangential = jnp.where(δ <= 0, jnp.zeros(3), f_tangential) # ===================================== # Compute the material deformation rate # ===================================== # Compute the derivative of the material deformation. # Note that we included an additional relaxation of `m_normal` in the # sticking case, so that the normal deformation that could have accumulated # from a previous slipping phase can relax to zero. ṁ_no_contact = -(K / D) * m ṁ_sticking = v_tangential - (K / D) * m_normal ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq) # Compute the contact status: # 0: slipping # 1: sticking # 2: no contact contact_status = sticking.astype(int) contact_status += (δ <= 0).astype(int) # Select the right material deformation rate depending on the contact status. ṁ = jax.lax.select_n(contact_status, ṁ_slipping, ṁ_sticking, ṁ_no_contact) # ========================================== # Compute and return the final contact force # ========================================== # Sum the normal and tangential forces. CW_fl = f_normal + f_tangential return CW_fl, ṁ @staticmethod @functools.partial(jax.jit, static_argnames=("terrain",)) def compute_contact_force( position: jtp.VectorLike, velocity: jtp.VectorLike, tangential_deformation: jtp.VectorLike, parameters: SoftContactsParams, terrain: Terrain, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute the contact force. Args: position: The position of the collidable point. velocity: The velocity of the collidable point. tangential_deformation: The material deformation of the collidable point. parameters: The parameters of the soft contacts model. terrain: The terrain model. Returns: A tuple containing the computed contact force and the derivative of the material deformation. 
""" CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model( position=position, velocity=velocity, tangential_deformation=tangential_deformation, terrain=terrain, K=parameters.K, D=parameters.D, mu=parameters.mu, p=parameters.p, q=parameters.q, ) # Pack a mixed 6D force. CW_f = jnp.hstack([CW_fl, jnp.zeros(3)]) # Compute the 6D force transform from the mixed to the inertial-fixed frame. W_Xf_CW = jaxsim.math.Adjoint.from_quaternion_and_translation( translation=jnp.array(position), inverse=True ).T # Compute the 6D force in the inertial-fixed frame. W_f = W_Xf_CW @ CW_f return W_f, ṁ @staticmethod @jax.jit def compute_contact_forces( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: """ Compute the contact forces. Args: model: The model to consider. data: The data of the considered model. Returns: A tuple containing as first element the computed contact forces, and as second element a dictionary with derivative of the material deformation. """ # Get the indices of the enabled collidable points. indices_of_enabled_collidable_points = ( model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points ) # Compute the position and linear velocities (mixed representation) of # all the collidable points belonging to the robot and extract the ones # for the enabled collidable points. W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data) # Extract the material deformation corresponding to the collidable points. m = ( data.contact_state["tangential_deformation"] if "tangential_deformation" in data.contact_state else jnp.zeros_like(W_p_C) ) m_enabled = m[indices_of_enabled_collidable_points] # Initialize the tangential deformation rate array for every collidable point. ṁ = jnp.zeros_like(m) # Compute the contact forces only for the enabled collidable points. # Since we treat them as independent, we can vmap the computation. 
W_f, ṁ_enabled = jax.vmap( lambda p, v, m: SoftContacts.compute_contact_force( position=p, velocity=v, tangential_deformation=m, parameters=model.contact_params, terrain=model.terrain, ) )(W_p_C, W_ṗ_C, m_enabled) ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled) return W_f, {"m_dot": ṁ} ================================================ FILE: src/jaxsim/rbda/crba.py ================================================ import jax import jax.numpy as jnp import jaxsim.api as js import jaxsim.typing as jtp from . import utils def crba( model: js.model.JaxSimModel, *, joint_positions: jtp.Vector, ) -> jtp.Matrix: """ Compute the free-floating mass matrix using the Composite Rigid-Body Algorithm (CRBA). Args: model: The model to consider. joint_positions: The positions of the joints. Returns: The free-floating mass matrix of the model in body-fixed representation. """ _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( model=model, joint_positions=joint_positions ) # Get the 6D spatial inertia matrices of all links. Mc = js.model.link_spatial_inertia_matrices(model=model) # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array # Compute the parent-to-child adjoints of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. i_X_λi = model.kin_dyn_parameters.joint_transforms( joint_positions=s, base_transform=jnp.eye(4) ) # Extract the joint motion subspaces. S = model.kin_dyn_parameters.motion_subspaces # Allocate the buffer of transforms link -> base. 
i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6)) i_X_0 = i_X_0.at[0].set(jnp.eye(6)) # ==================== # Propagate kinematics # ==================== ForwardPassCarry = tuple[jtp.Matrix] forward_pass_carry: ForwardPassCarry = (i_X_0,) def propagate_kinematics( carry: ForwardPassCarry, i: jtp.Int ) -> tuple[ForwardPassCarry, None]: (i_X_0,) = carry i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]] i_X_0 = i_X_0.at[i].set(i_X_0_i) return (i_X_0,), None (i_X_0,), _ = ( jax.lax.scan( f=propagate_kinematics, init=forward_pass_carry, xs=jnp.arange(start=1, stop=model.number_of_links()), ) if model.number_of_links() > 1 else [(i_X_0,), None] ) # =================== # Compute mass matrix # =================== M = jnp.zeros(shape=(6 + model.dofs(), 6 + model.dofs())) BackwardPassCarry = tuple[jtp.Matrix, jtp.Matrix] backward_pass_carry: BackwardPassCarry = (Mc, M) def backward_pass( carry: BackwardPassCarry, i: jtp.Int ) -> tuple[BackwardPassCarry, None]: ii = i - 1 Mc, M = carry Mc_λi = Mc[λ[i]] + i_X_λi[i].T @ Mc[i] @ i_X_λi[i] Mc = Mc.at[λ[i]].set(Mc_λi) Fi = Mc[i] @ S[i] M_ii = S[i].T @ Fi M = M.at[ii + 6, ii + 6].set(M_ii.squeeze()) j = i FakeWhileCarry = tuple[jtp.Int, jtp.Vector, jtp.Matrix] fake_while_carry = (j, Fi, M) # This internal for loop implements the while loop of the CRBA algorithm # to compute off-diagonal blocks of the mass matrix M. # In pseudocode it is implemented as a while loop. However, in order to enable # applying reverse-mode AD, we implement it as a nested for loop with a fixed # number of iterations and a branching model to skip for loop iterations. 
def fake_while_loop( carry: FakeWhileCarry, i: jtp.Int ) -> tuple[FakeWhileCarry, None]: def compute(carry: FakeWhileCarry) -> FakeWhileCarry: j, Fi, M = carry Fi = i_X_λi[j].T @ Fi j = λ[j] M_ij = Fi.T @ S[j] jj = j - 1 M = M.at[ii + 6, jj + 6].set(M_ij.squeeze()) M = M.at[jj + 6, ii + 6].set(M_ij.squeeze()) return j, Fi, M j, _, _ = carry j, Fi, M = jax.lax.cond( pred=jnp.logical_and(i == λ[j], λ[j] > 0), true_fun=compute, false_fun=lambda carry: carry, operand=carry, ) return (j, Fi, M), None (j, Fi, M), _ = ( jax.lax.scan( f=fake_while_loop, init=fake_while_carry, xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), ) if model.number_of_links() > 1 else [(j, Fi, M), None] ) Fi = i_X_0[j].T @ Fi M = M.at[0:6, ii + 6].set(Fi.squeeze()) M = M.at[ii + 6, 0:6].set(Fi.squeeze()) return (Mc, M), None # This scan performs the backward pass to compute Mbj, Mjb and Mjj, that # also includes a fake while loop implemented with a scan and two cond. (Mc, M), _ = ( jax.lax.scan( f=backward_pass, init=backward_pass_carry, xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), ) if model.number_of_links() > 1 else [(Mc, M), None] ) # Store the locked 6D rigid-body inertia matrix Mbb ∈ ℝ⁶ˣ⁶. M = M.at[0:6, 0:6].set(Mc[0]) return M ================================================ FILE: src/jaxsim/rbda/forward_kinematics.py ================================================ import jax import jax.numpy as jnp import jaxsim.api as js import jaxsim.typing as jtp from jaxsim.math import Adjoint from . import utils def forward_kinematics_model( model: js.model.JaxSimModel, *, base_position: jtp.VectorLike, base_quaternion: jtp.VectorLike, joint_positions: jtp.VectorLike, base_linear_velocity_inertial: jtp.VectorLike, base_angular_velocity_inertial: jtp.VectorLike, joint_velocities: jtp.VectorLike, joint_transforms: jtp.MatrixLike, ) -> tuple[jtp.Array, jtp.Array]: """ Compute the forward kinematics. Args: model: The model to consider. 
base_position: The position of the base link. base_quaternion: The quaternion of the base link. joint_positions: The positions of the joints. base_linear_velocity_inertial: The linear velocity of the base link in inertial-fixed representation. base_angular_velocity_inertial: The angular velocity of the base link in inertial-fixed representation. joint_velocities: The velocities of the joints. joint_transforms: The parent-to-child transforms of the joints. Returns: A tuple containing the SE(3) transforms and the 6D velocities of all links. """ _, _, _, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs( model=model, base_position=base_position, base_quaternion=base_quaternion, joint_positions=joint_positions, base_linear_velocity=base_linear_velocity_inertial, base_angular_velocity=base_angular_velocity_inertial, joint_velocities=joint_velocities, ) # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array # Extract the parent-to-child adjoints of the joints. i_X_λi = jnp.asarray(joint_transforms) # Allocate the buffer of transforms world -> link and initialize the base pose. W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) W_X_i = W_X_i.at[0].set(Adjoint.inverse(i_X_λi[0])) # Allocate buffer of 6D inertial-fixed velocities and initialize the base velocity. W_v_Wi = jnp.zeros(shape=(model.number_of_links(), 6)) W_v_Wi = W_v_Wi.at[0].set(W_v_WB) # Extract the joint motion subspaces. S = model.kin_dyn_parameters.motion_subspaces # ======================== # Propagate the kinematics # ======================== PropagateKinematicsCarry = tuple[jtp.Matrix, jtp.Matrix] propagate_kinematics_carry: PropagateKinematicsCarry = (W_X_i, W_v_Wi) def propagate_kinematics( carry: PropagateKinematicsCarry, i: jtp.Int ) -> tuple[PropagateKinematicsCarry, None]: ii = i - 1 W_X_i, W_v_Wi = carry # Compute the parent to child 6D transform. 
λi_X_i = Adjoint.inverse(adjoint=i_X_λi[i]) # Compute the world to child 6D transform. W_Xi_i = W_X_i[λ[i]] @ λi_X_i W_X_i = W_X_i.at[i].set(W_Xi_i) # Propagate the 6D velocity. W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * ṡ[ii]).squeeze() W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi) return (W_X_i, W_v_Wi), None (W_X_i, W_v_Wi), _ = ( jax.lax.scan( f=propagate_kinematics, init=propagate_kinematics_carry, xs=jnp.arange(start=1, stop=model.number_of_links()), ) if model.number_of_links() > 1 else [(W_X_i, W_v_Wi), None] ) return jax.vmap(Adjoint.to_transform)(W_X_i), W_v_Wi ================================================ FILE: src/jaxsim/rbda/forward_kinematics_parallel.py ================================================ import math import jax import jax.numpy as jnp import jaxsim.api as js import jaxsim.typing as jtp from jaxsim.math import Adjoint from . import utils def forward_kinematics_model_parallel( model: js.model.JaxSimModel, *, base_position: jtp.VectorLike, base_quaternion: jtp.VectorLike, joint_positions: jtp.VectorLike, base_linear_velocity_inertial: jtp.VectorLike, base_angular_velocity_inertial: jtp.VectorLike, joint_velocities: jtp.VectorLike, joint_transforms: jtp.MatrixLike, ) -> tuple[jtp.Array, jtp.Array]: """ Compute forward kinematics using pointer jumping on the kinematic tree. Uses an associative binary operator on transform-velocity pairs to compute all world-frame transforms and velocities in O(log D) parallel steps, where D is the tree depth. The interface and semantics are identical to :func:`forward_kinematics_model`, but parallelized via pointer jumping. """ _, _, _, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs( model=model, base_position=base_position, base_quaternion=base_quaternion, joint_positions=joint_positions, base_linear_velocity=base_linear_velocity_inertial, base_angular_velocity=base_angular_velocity_inertial, joint_velocities=joint_velocities, ) # Extract the parent-to-child adjoints of the joints. 
i_X_λi = jnp.asarray(joint_transforms) # Extract the joint motion subspaces. S = model.kin_dyn_parameters.motion_subspaces n = model.number_of_links() # Compute local transforms λ(i)_X_i by inverting the child-to-parent adjoints. L = jax.vmap(Adjoint.inverse)(i_X_λi) # (n, 6, 6) # Compute local velocity contributions. ṡ_padded = jnp.concatenate([jnp.zeros(1), jnp.atleast_1d(ṡ.squeeze())]) # (n,) vJ = (S * ṡ_padded[:, None, None]).squeeze(-1) # (n, 6) u = jnp.einsum("nij,nj->ni", L, vJ) # (n, 6) u = u.at[0].set(W_v_WB) # Get the parent array λ(i) with root self-loop. # Note: λ(0) is set to 0 to enable root self-referencing. ptr = jnp.asarray(model.kin_dyn_parameters.parent_array).at[0].set(0) done = jnp.arange(n) == 0 # Number of pointer-jumping rounds. n_levels = model.kin_dyn_parameters.level_nodes.shape[0] n_rounds = max(1, math.ceil(math.log2(max(n_levels, 2)))) # =============== # Pointer jumping # =============== # Each round composes the node state with its current pointer target, # then doubles the jump distance. After ceil(log2 D) rounds every node # has accumulated the full root-to-node transform and velocity. def _pointer_jump(carry, _): L, u, ptr, done = carry need = ~done L_par = L[ptr] u_par = u[ptr] # Associative compose. L_new = jnp.where(need[:, None, None], L_par @ L, L) u_new = jnp.where( need[:, None], u_par + jnp.einsum("nij,nj->ni", L_par, u), u, ) ptr_new = jnp.where(need, ptr[ptr], ptr) done_new = done | done[ptr] return (L_new, u_new, ptr_new, done_new), None (W_X_i, W_v_Wi, _, _), _ = ( jax.lax.scan( f=_pointer_jump, init=(L, u, ptr, done), xs=jnp.arange(n_rounds), ) if n > 1 else ((L, u, ptr, done), None) ) return jax.vmap(Adjoint.to_transform)(W_X_i), W_v_Wi ================================================ FILE: src/jaxsim/rbda/jacobian.py ================================================ import jax import jax.numpy as jnp import numpy as np import jaxsim.api as js import jaxsim.typing as jtp from jaxsim.math import Adjoint, Cross from . 
import utils def jacobian( model: js.model.JaxSimModel, *, link_index: jtp.Int, joint_positions: jtp.VectorLike, ) -> jtp.Matrix: """ Compute the free-floating Jacobian of a link. Args: model: The model to consider. link_index: The index of the link for which to compute the Jacobian matrix. joint_positions: The positions of the joints. Returns: The free-floating left-trivialized Jacobian of the link :math:`{}^L J_{W,L/B}`. """ _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( model=model, joint_positions=joint_positions ) # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array # Compute the parent-to-child adjoints of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. i_X_λi = model.kin_dyn_parameters.joint_transforms( joint_positions=s, base_transform=jnp.eye(4) ) # Extract the joint motion subspaces. S = model.kin_dyn_parameters.motion_subspaces # Allocate the buffer of transforms link -> base. i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6)) i_X_0 = i_X_0.at[0].set(jnp.eye(6)) # ==================== # Propagate kinematics # ==================== PropagateKinematicsCarry = tuple[jtp.Matrix] propagate_kinematics_carry: PropagateKinematicsCarry = (i_X_0,) def propagate_kinematics( carry: PropagateKinematicsCarry, i: jtp.Int ) -> tuple[PropagateKinematicsCarry, None]: (i_X_0,) = carry # Compute the base (0) to link (i) adjoint matrix. # This works fine since we traverse the kinematic tree following the link # indices assigned with BFS. 
i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]] i_X_0 = i_X_0.at[i].set(i_X_0_i) return (i_X_0,), None (i_X_0,), _ = ( jax.lax.scan( f=propagate_kinematics, init=propagate_kinematics_carry, xs=np.arange(start=1, stop=model.number_of_links()), ) if model.number_of_links() > 1 else [(i_X_0,), None] ) # ============================ # Compute doubly-left Jacobian # ============================ J = jnp.zeros(shape=(6, 6 + model.dofs())) Jb = i_X_0[link_index] J = J.at[0:6, 0:6].set(Jb) # To make JIT happy, we operate on a boolean version of κ(i). # Checking if j ∈ κ(i) is equivalent to: κ_bool(j) is True. κ_bool = model.kin_dyn_parameters.support_body_array_bool[link_index] def compute_jacobian(J: jtp.Matrix, i: jtp.Int) -> tuple[jtp.Matrix, None]: def update_jacobian(J: jtp.Matrix, i: jtp.Int) -> jtp.Matrix: ii = i - 1 Js_i = i_X_0[link_index] @ Adjoint.inverse(i_X_0[i]) @ S[i] J = J.at[0:6, 6 + ii].set(Js_i.squeeze()) return J J = jax.lax.select( pred=κ_bool[i], on_true=update_jacobian(J, i), on_false=J, ) return J, None L_J_WL_B, _ = ( jax.lax.scan( f=compute_jacobian, init=J, xs=np.arange(start=1, stop=model.number_of_links()), ) if model.number_of_links() > 1 else [J, None] ) return L_J_WL_B @jax.jit def jacobian_full_doubly_left( model: js.model.JaxSimModel, *, joint_positions: jtp.VectorLike, ) -> tuple[jtp.Matrix, jtp.Array]: r""" Compute the doubly-left full free-floating Jacobian of a model. The full Jacobian is a 6x(6+n) matrix with all the columns filled. It is useful to run the algorithm once, and then extract the link Jacobian by filtering the columns of the full Jacobian using the support parent array :math:`\kappa(i)` of the link. Args: model: The model to consider. joint_positions: The positions of the joints. Returns: The doubly-left full free-floating Jacobian of a model. """ _, _, s, _, _, _, _, _, _, _ = utils.process_inputs( model=model, joint_positions=joint_positions ) # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. 
λ = model.kin_dyn_parameters.parent_array # Compute the parent-to-child adjoints of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. i_X_λi = model.kin_dyn_parameters.joint_transforms( joint_positions=s, base_transform=jnp.eye(4) ) # Extract the joint motion subspaces. S = model.kin_dyn_parameters.motion_subspaces # Allocate the buffer of transforms base -> link. B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) B_X_i = B_X_i.at[0].set(jnp.eye(6)) # ================================= # Compute doubly-left full Jacobian # ================================= # Allocate the Jacobian matrix. # The Jbb section of the doubly-left Jacobian is an identity matrix. J = jnp.zeros(shape=(6, 6 + model.dofs())) J = J.at[0:6, 0:6].set(jnp.eye(6)) ComputeFullJacobianCarry = tuple[jtp.Matrix, jtp.Matrix] compute_full_jacobian_carry: ComputeFullJacobianCarry = (B_X_i, J) def compute_full_jacobian( carry: ComputeFullJacobianCarry, i: jtp.Int ) -> tuple[ComputeFullJacobianCarry, None]: ii = i - 1 B_X_i, J = carry # Compute the base (0) to link (i) adjoint matrix. B_Xi_i = B_X_i[λ[i]] @ Adjoint.inverse(i_X_λi[i]) B_X_i = B_X_i.at[i].set(B_Xi_i) # Compute the ii-th column of the B_S_BL(s) matrix. B_Sii_BL = B_Xi_i @ S[i] J = J.at[0:6, 6 + ii].set(B_Sii_BL.squeeze()) return (B_X_i, J), None (B_X_i, J), _ = ( jax.lax.scan( f=compute_full_jacobian, init=compute_full_jacobian_carry, xs=np.arange(start=1, stop=model.number_of_links()), ) if model.number_of_links() > 1 else [(B_X_i, J), None] ) # Convert adjoints to SE(3) transforms. # Returning them here prevents calling FK in case the output representation # of the Jacobian needs to be changed. B_H_L = jax.vmap(Adjoint.to_transform)(B_X_i) # Adjust shape of doubly-left free-floating full Jacobian. 
B_J_full_WL_B = J.squeeze().astype(float) return B_J_full_WL_B, B_H_L def jacobian_derivative_full_doubly_left( model: js.model.JaxSimModel, *, joint_positions: jtp.VectorLike, joint_velocities: jtp.VectorLike, ) -> tuple[jtp.Matrix, jtp.Array]: r""" Compute the derivative of the doubly-left full free-floating Jacobian of a model. The derivative of the full Jacobian is a 6x(6+n) matrix with all the columns filled. It is useful to run the algorithm once, and then extract the link Jacobian derivative by filtering the columns of the full Jacobian using the support parent array :math:`\kappa(i)` of the link. Args: model: The model to consider. joint_positions: The positions of the joints. joint_velocities: The velocities of the joints. Returns: The derivative of the doubly-left full free-floating Jacobian of a model. """ _, _, s, _, ṡ, _, _, _, _, _ = utils.process_inputs( model=model, joint_positions=joint_positions, joint_velocities=joint_velocities ) # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array # Compute the parent-to-child adjoints of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. i_X_λi = model.kin_dyn_parameters.joint_transforms( joint_positions=s, base_transform=jnp.eye(4) ) # Extract the joint motion subspaces. S = model.kin_dyn_parameters.motion_subspaces # Allocate the buffer of 6D transform base -> link. B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) B_X_i = B_X_i.at[0].set(jnp.eye(6)) # Allocate the buffer of 6D transform derivatives base -> link. B_Ẋ_i = jnp.zeros(shape=(model.number_of_links(), 6, 6)) # Allocate the buffer of the 6D link velocity in body-fixed representation. B_v_Bi = jnp.zeros(shape=(model.number_of_links(), 6)) # Helper to compute the time derivative of the adjoint matrix. 
def A_Ẋ_B(A_X_B: jtp.Matrix, B_v_AB: jtp.Vector) -> jtp.Matrix: return A_X_B @ Cross.vx(B_v_AB).squeeze() # ============================================ # Compute doubly-left full Jacobian derivative # ============================================ # Allocate the Jacobian matrix. J̇ = jnp.zeros(shape=(6, 6 + model.dofs())) ComputeFullJacobianDerivativeCarry = tuple[ jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix ] compute_full_jacobian_derivative_carry: ComputeFullJacobianDerivativeCarry = ( B_v_Bi, B_X_i, B_Ẋ_i, J̇, ) def compute_full_jacobian_derivative( carry: ComputeFullJacobianDerivativeCarry, i: jtp.Int ) -> tuple[ComputeFullJacobianDerivativeCarry, None]: ii = i - 1 B_v_Bi, B_X_i, B_Ẋ_i, J̇ = carry # Compute the base (0) to link (i) adjoint matrix. B_Xi_i = B_X_i[λ[i]] @ Adjoint.inverse(i_X_λi[i]) B_X_i = B_X_i.at[i].set(B_Xi_i) # Compute the body-fixed velocity of the link. B_vi_Bi = B_v_Bi[λ[i]] + B_X_i[i] @ S[i].squeeze() * ṡ[ii] B_v_Bi = B_v_Bi.at[i].set(B_vi_Bi) # Compute the base (0) to link (i) adjoint matrix derivative. i_Xi_B = Adjoint.inverse(B_Xi_i) B_Ẋi_i = A_Ẋ_B(A_X_B=B_Xi_i, B_v_AB=i_Xi_B @ B_vi_Bi) B_Ẋ_i = B_Ẋ_i.at[i].set(B_Ẋi_i) # Compute the ii-th column of the B_Ṡ_BL(s) matrix. B_Ṡii_BL = B_Ẋ_i[i] @ S[i] J̇ = J̇.at[0:6, 6 + ii].set(B_Ṡii_BL.squeeze()) return (B_v_Bi, B_X_i, B_Ẋ_i, J̇), None (_, B_X_i, B_Ẋ_i, J̇), _ = ( jax.lax.scan( f=compute_full_jacobian_derivative, init=compute_full_jacobian_derivative_carry, xs=np.arange(start=1, stop=model.number_of_links()), ) if model.number_of_links() > 1 else [(_, B_X_i, B_Ẋ_i, J̇), None] ) # Convert adjoints to SE(3) transforms. # Returning them here prevents calling FK in case the output representation # of the Jacobian needs to be changed. B_H_L = jax.vmap(Adjoint.to_transform)(B_X_i) # Adjust shape of doubly-left free-floating full Jacobian derivative. 
B_J̇_full_WL_B = J̇.squeeze().astype(float) return B_J̇_full_WL_B, B_H_L ================================================ FILE: src/jaxsim/rbda/kinematic_constraints.py ================================================ from __future__ import annotations import jax import jax.numpy as jnp import jaxsim.api as js import jaxsim.typing as jtp from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr from jaxsim.api.kin_dyn_parameters import ConstraintMap from jaxsim.math.adjoint import Adjoint from jaxsim.math.rotation import Rotation from jaxsim.math.transform import Transform # Utility functions used for constraints computation. These functions duplicate part of the jaxsim.api.frame module for computational efficiency. # TODO: remove these functions when jaxsim.api.frame is optimized for batched computations. # See: https://github.com/gbionics/jaxsim/issues/451 def _compute_constraint_transforms_batched( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, constraints: ConstraintMap, ) -> jtp.Matrix: """ Compute the transformation matrices for kinematic constraints between pairs of frames. Args: model: The JaxSim model containing the robot description. data: The model data containing current state information. constraints: The constraint map containing frame indices and parent link information. Returns: A matrix with shape (n_constraints, 2, 4, 4) containing the transformation matrices for each constraint pair. The second dimension contains [W_H_F1, W_H_F2] where W_H_F1 and W_H_F2 are the world-to-frame transformation matrices. 
""" W_H_L = data._link_transforms frame_idxs_1 = constraints.frame_idxs_1 frame_idxs_2 = constraints.frame_idxs_2 parent_link_idxs_1 = constraints.parent_link_idxs_1 parent_link_idxs_2 = constraints.parent_link_idxs_2 # Extract frame transforms L_H_F1 = model.kin_dyn_parameters.frame_parameters.transform[ frame_idxs_1 - model.number_of_links() ] L_H_F2 = model.kin_dyn_parameters.frame_parameters.transform[ frame_idxs_2 - model.number_of_links() ] # Compute the homogeneous transformation matrices for the two frames W_H_F1 = W_H_L[parent_link_idxs_1] @ L_H_F1 W_H_F2 = W_H_L[parent_link_idxs_2] @ L_H_F2 return jnp.stack([W_H_F1, W_H_F2], axis=1) def _compute_constraint_jacobians_batched( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, constraints: ConstraintMap, W_H_constraint_pairs: jtp.Matrix, ) -> jtp.Matrix: """ Compute the constraint Jacobian matrices for kinematic constraints in a batched manner. Args: model: The JaxSim model containing the robot description. data: The model data containing current state information. constraints: The constraint map containing frame indices and parent link information. W_H_constraint_pairs: Transformation matrices for constraint frame pairs with shape (n_constraints, 2, 4, 4). Returns: A matrix with shape (n_constraints, 6, n_dofs) containing the constraint Jacobian matrices. """ with data.switch_velocity_representation(VelRepr.Body): # Doubly-left free-floating Jacobian. 
L_J_WL_B = js.model.generalized_free_floating_jacobian( model=model, data=data, output_vel_repr=VelRepr.Body ) # Link transforms W_H_L = data._link_transforms def compute_frame_jacobian_mixed(L_J_WL, W_H_L, W_H_F, parent_link_index): """Compute the jacobian of a frame in mixed representation.""" # Select the jacobian of the parent link L_J_WL = L_J_WL[parent_link_index] # Compute the jacobian of the frame in mixed representation W_H_L = W_H_L[parent_link_index] F_H_L = Transform.inverse(W_H_F) @ W_H_L FW_H_F = W_H_F.at[0:3, 3].set(jnp.zeros(3)) FW_H_L = FW_H_F @ F_H_L FW_X_L = Adjoint.from_transform(transform=FW_H_L) FW_J_WL = FW_X_L @ L_J_WL O_J_WL_I = FW_J_WL return O_J_WL_I def compute_constraint_jacobian(L_J_WL, W_H_F, constraint): """Compute the constraint jacobian for a single constraint pair.""" J_WF1 = compute_frame_jacobian_mixed( L_J_WL, W_H_L, W_H_F[0], constraint.parent_link_idxs_1 ) J_WF2 = compute_frame_jacobian_mixed( L_J_WL, W_H_L, W_H_F[1], constraint.parent_link_idxs_2 ) return J_WF1 - J_WF2 # Vectorize the computation of constraint Jacobians constraint_jacobians = jax.vmap(compute_constraint_jacobian, in_axes=(None, 0, 0))( L_J_WL_B, W_H_constraint_pairs, constraints ) return constraint_jacobians def _compute_constraint_baumgarte_term( J_constr: jtp.Matrix, nu: jtp.Vector, W_H_F_constr: jtp.Matrix, constraint: ConstraintMap, ) -> jtp.Vector: """ Compute the Baumgarte stabilization term for kinematic constraints. The Baumgarte stabilization method is used to stabilize constraint violations by adding proportional and derivative terms to the constraint equation. This helps prevent constraint drift and improves numerical stability. Args: J_constr: The constraint Jacobian matrix with shape (6, n_dofs). nu: The generalized velocity vector with shape (n_dofs,). W_H_F_constr: Array containing the homogeneous transformation matrices of two frames [W_H_F1, W_H_F2] with respect to the world frame, with shape (2, 4, 4). 
constraint: The constraint object containing stabilization gains K_P and K_D. Returns: The computed Baumgarte stabilization term. """ W_H_F1, W_H_F2 = W_H_F_constr W_p_F1 = W_H_F1[0:3, 3] W_p_F2 = W_H_F2[0:3, 3] W_R_F1 = W_H_F1[0:3, 0:3] W_R_F2 = W_H_F2[0:3, 0:3] K_P = constraint.K_P K_D = constraint.K_D vel_error = J_constr @ nu position_error = W_p_F1 - W_p_F2 R_error = W_R_F2.T @ W_R_F1 orientation_error = Rotation.log_vee(R_error) baumgarte_term = ( K_P * jnp.concatenate([position_error, orientation_error]) + K_D * vel_error ) return baumgarte_term @jax.jit @js.common.named_scope def compute_constraint_wrenches( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, joint_force_references: jtp.VectorLike | None = None, link_forces_inertial: jtp.MatrixLike | None = None, regularization: jtp.Float = 1e-3, ) -> jtp.Matrix: """ Compute the constraint wrenches for kinematic constraints. This function solves the constraint forces needed to satisfy kinematic constraints between pairs of frames. It uses the Baumgarte stabilization method and computes the constraint wrenches in inertial representation. Args: model: The JaxSim model. data: The model data. joint_force_references: Optional joint force/torque references to apply. If None, zero forces are used. link_forces_inertial: Optional link forces applied in inertial representation. If None, zero forces are used. regularization: Regularization parameter for the constraint solver to improve numerical stability. Default is 1e-3. Returns: Array with shape (n_constraints, 2, 6) containing constraint wrench pairs in inertial representation. Each constraint produces two equal and opposite wrenches applied to the constrained frames. """ # Retrieve the kinematic constraints, if any. 
kin_constraints = model.kin_dyn_parameters.constraints n_kin_constraints = ( 6 * kin_constraints.frame_idxs_1.shape[0] if kin_constraints is not None and kin_constraints.frame_idxs_1.shape[0] > 0 else 0 ) # Return empty results if no constraints exist if n_kin_constraints == 0: return jnp.zeros((0, 2, 6)) # Build joint forces if not provided τ_references = ( jnp.asarray(joint_force_references, dtype=float) if joint_force_references is not None else jnp.zeros_like(data.joint_positions) ) # Build link forces if not provided W_f_L = ( jnp.atleast_2d(jnp.array(link_forces_inertial).squeeze()) if link_forces_inertial is not None else jnp.zeros((model.number_of_links(), 6)) ).astype(float) # Create references object for handling different velocity representations references = js.references.JaxSimModelReferences.build( model=model, joint_force_references=τ_references, link_forces=W_f_L, velocity_representation=VelRepr.Inertial, ) with ( data.switch_velocity_representation(VelRepr.Mixed), references.switch_velocity_representation(VelRepr.Mixed), ): BW_ν = data.generalized_velocity # Compute free acceleration without constraints BW_ν̇_free = jnp.hstack( js.model.forward_dynamics_aba( model=model, data=data, link_forces=references.link_forces(model=model, data=data), joint_forces=references.joint_force_references(model=model), ) ) # Compute mass matrix M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data) W_H_constr_pairs = _compute_constraint_transforms_batched( model=model, data=data, constraints=kin_constraints, ) # Compute constraint jacobians J_constr = _compute_constraint_jacobians_batched( model=model, data=data, constraints=kin_constraints, W_H_constraint_pairs=W_H_constr_pairs, ) # Compute Baumgarte stabilization term constr_baumgarte_term = jnp.ravel( jax.vmap( _compute_constraint_baumgarte_term, in_axes=(0, None, 0, 0), )( J_constr, BW_ν, W_H_constr_pairs, kin_constraints, ), ) # Stack constraint jacobians J_constr = jnp.vstack(J_constr) # 
Compute Delassus matrix for constraints G_constraints = J_constr @ M_inv @ J_constr.T # Compute constraint acceleration # TODO: add J̇_constr with efficient computation CW_al_free_constr = J_constr @ BW_ν̇_free # Setup constraint optimization problem constraint_regularization = regularization * jnp.ones(n_kin_constraints) R = jnp.diag(constraint_regularization) A = G_constraints + R b = CW_al_free_constr + constr_baumgarte_term # Solve for constraint forces kin_constr_wrench_mixed = jnp.linalg.solve(A, -b).reshape(-1, 6) def transform_wrenches_to_inertial(wrench, transform_pair): """ Transform wrench pairs in inertial representation. Args: wrench: Wrench vector with shape (6,). transform_pair: Pair of transformation matrices [W_H_F1, W_H_F2] Returns: Stack of transformed wrenches with shape (2, 6). """ W_H_F1, W_H_F2 = transform_pair[0], transform_pair[1] wrench_F1 = wrench wrench_F2 = -wrench # Create wrench pair directly # Transform both at once wrench_F1_inertial = ( ModelDataWithVelocityRepresentation.other_representation_to_inertial( array=wrench_F1, transform=W_H_F1, other_representation=VelRepr.Mixed, is_force=True, ) ) wrench_F2_inertial = ( ModelDataWithVelocityRepresentation.other_representation_to_inertial( array=wrench_F2, transform=W_H_F2, other_representation=VelRepr.Mixed, is_force=True, ) ) return jnp.stack([wrench_F1_inertial, wrench_F2_inertial]) kin_constr_wrench_pairs_inertial = jax.vmap(transform_wrenches_to_inertial)( kin_constr_wrench_mixed, W_H_constr_pairs ) return kin_constr_wrench_pairs_inertial ================================================ FILE: src/jaxsim/rbda/mass_inverse.py ================================================ import jax import jax.numpy as jnp import jaxsim.api as js import jaxsim.typing as jtp def mass_inverse( model: js.model.JaxSimModel, *, joint_transforms: jtp.MatrixLike, ) -> jtp.Matrix: """ Compute the inverse of the mass matrix using an ABA-like algorithm. 
The implementation follows the approach described in https://laas.hal.science/hal-01790934v2. Args: model: The model to consider. joint_transforms: The parent-to-child transforms of the joints. Returns: The inverse of the mass matrix. """ # Get the 6D spatial inertia matrices of all links. I_A = js.model.link_spatial_inertia_matrices(model=model) # Get the parent array λ(i). # λ[0] ~ -1 (world) # λ[i] = parent link index for link i. λ = model.kin_dyn_parameters.parent_array # Extract the parent-to-child adjoints of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. i_X_λi = jnp.asarray(joint_transforms) # Extract the joint motion subspaces. S = model.kin_dyn_parameters.motion_subspaces NB = model.number_of_links() N = model.number_of_joints() # Total generalized velocities: 6 base + N. nv = N + 6 # Allocate buffers. F = jnp.zeros((NB, 6, nv), dtype=float) P = jnp.zeros((NB, 6, nv), dtype=float) U = jnp.zeros((NB, 6), dtype=float) D = jnp.zeros((NB,), dtype=float) # Pre-allocate mass matrix inverse M_inv = jnp.zeros((nv, nv), dtype=float) # Pre-compute indices. 
idx_fwd = jnp.arange(1, NB) idx_rev = jnp.arange(NB - 1, 0, -1) # ============= # Backward Pass # ============= BackwardPassCarry = tuple[ jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix ] backward_pass_carry: BackwardPassCarry = (I_A, F, U, D, M_inv) def loop_backward_pass( carry: BackwardPassCarry, i: jtp.Int ) -> tuple[BackwardPassCarry, None]: I_A, F, U, D, M_inv = carry Si = jnp.squeeze(S[i], axis=-1) Fi = F[i] Xi = i_X_λi[i] parent = λ[i] Ui = I_A[i] @ Si Di = jnp.dot(Si, Ui) U = U.at[i].set(Ui) D = D.at[i].set(Di) # Row index in ν for joint i: 6 + (i - 1) r = 6 + (i - 1) Minv_row = M_inv[r] # Diagonal element Minv_row = Minv_row.at[r].add(1.0 / Di) # Off-diagonals: Minv[r,:] -= (1/Di) * Sᵢᵀ Fᵢ sTFi = jnp.einsum("s,sn->n", Si, Fi) Minv_row = Minv_row - sTFi / Di M_inv = M_inv.at[r].set(Minv_row) # Propagate to parent if any (parent >= 0) def propagate(IA_F): I_A_, F_ = IA_F Ui_col = Ui[:, None] # F_a_i = F_i + U_i * Minv[r,:] Fa_i = Fi + Ui_col @ Minv_row[None, :] # F_parent += Xᵢᵀ F_a_i F_parent_new = F_[parent] + Xi.T @ Fa_i F_ = F_.at[parent].set(F_parent_new) # I_a_i = IAi - U_i D_i^{-1} U_iᵀ Ia_i = I_A[i] - jnp.outer(Ui, Ui) / Di # I_A[parent] += Xᵢᵀ I_a_i Xᵢ I_parent_new = I_A_[parent] + Xi.T @ Ia_i @ Xi I_A_ = I_A_.at[parent].set(I_parent_new) return I_A_, F_ I_A, F = jax.lax.cond( parent >= 0, propagate, lambda IA_F: IA_F, (I_A, F), ) return (I_A, F, U, D, M_inv), None (I_A, F, U, D, M_inv), _ = jax.lax.scan( loop_backward_pass, backward_pass_carry, idx_rev ) S0 = jnp.eye(6, dtype=float) U0 = I_A[0] @ S0 D0 = S0.T @ U0 D0_inv = jnp.linalg.inv(D0) # Base rows 0..5 in ν base_rows = slice(0, 6) # Diagonal base block M_inv = M_inv.at[base_rows, base_rows].add(D0_inv) # Off-diagonal base contribution: M_inv[base,:] -= D0^{-T} F[0] term0 = D0_inv.T @ F[0] M_inv = M_inv.at[base_rows, :].add(-term0) # ============ # Forward Pass # ============ # Initialize P_0 = S0 * Minv[base,:] = I * Minv[base,:] Minv_base = M_inv[base_rows, :] P = 
P.at[0].set(Minv_base) ForwardPassCarry = tuple[jtp.Matrix, jtp.Matrix] forward_pass_carry: ForwardPassCarry = (M_inv, P) def loop_forward_pass( carry: ForwardPassCarry, i: jtp.Int ) -> tuple[ForwardPassCarry, None]: M_inv, P = carry Si = jnp.squeeze(S[i], axis=-1) Ui = U[i] Di = D[i] Xi = i_X_λi[i] parent = λ[i] P_parent = jax.lax.cond( parent >= 0, lambda P_: P_[parent], lambda P_: jnp.zeros_like(P_[i]), P, ) # Row index in ν for joint i r = 6 + (i - 1) # Row update: M_inv[r,:] -= D_i^{-1} U_iᵀ Xᵢ P_parent def update_row(Minv_): X_P = Xi @ P_parent UiT_XP = jnp.einsum("s,sn->n", Ui, X_P) Minv_row = Minv_[r, :] - UiT_XP / Di return Minv_.at[r, :].set(Minv_row) M_inv = jax.lax.cond( parent >= 0, update_row, lambda Minv_: Minv_, M_inv, ) Minv_row = M_inv[r, :] # P_i = S_i Minv[r,:] + Xᵢ P_parent Pi = jnp.expand_dims(Si, 1) @ jnp.expand_dims(Minv_row, 0) Pi = Pi + Xi @ P_parent P = P.at[i].set(Pi) return (M_inv, P), None (M_inv, P), _ = jax.lax.scan(loop_forward_pass, forward_pass_carry, idx_fwd) # Symmetrize numerically M_inv = 0.5 * (M_inv + M_inv.T) return M_inv ================================================ FILE: src/jaxsim/rbda/rnea.py ================================================ import jax import jax.numpy as jnp import jaxlie import jaxsim.api as js import jaxsim.typing as jtp from jaxsim.math import STANDARD_GRAVITY, Adjoint, Cross from . 
import utils def rnea( model: js.model.JaxSimModel, *, base_position: jtp.Vector, base_quaternion: jtp.Vector, joint_positions: jtp.Vector, base_linear_velocity: jtp.Vector, base_angular_velocity: jtp.Vector, joint_velocities: jtp.Vector, base_linear_acceleration: jtp.Vector | None = None, base_angular_acceleration: jtp.Vector | None = None, joint_accelerations: jtp.Vector | None = None, joint_transforms: jtp.MatrixLike, link_forces: jtp.Matrix | None = None, standard_gravity: jtp.FloatLike = STANDARD_GRAVITY, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA). Args: model: The model to consider. base_position: The position of the base link. base_quaternion: The quaternion of the base link. joint_positions: The positions of the joints. base_linear_velocity: The linear velocity of the base link in inertial-fixed representation. base_angular_velocity: The angular velocity of the base link in inertial-fixed representation. joint_velocities: The velocities of the joints. base_linear_acceleration: The linear acceleration of the base link in inertial-fixed representation. base_angular_acceleration: The angular acceleration of the base link in inertial-fixed representation. joint_accelerations: The accelerations of the joints. joint_transforms: The parent-to-child transforms of the joints. link_forces: The forces applied to the links expressed in the world frame. standard_gravity: The standard gravity constant. Returns: A tuple containing the 6D force applied to the base link expressed in the world frame and the joint forces that, when applied respectively to the base link and joints, produce the given base and joint accelerations. 
""" W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, _, W_f, W_g = utils.process_inputs( model=model, base_position=base_position, base_quaternion=base_quaternion, joint_positions=joint_positions, base_linear_velocity=base_linear_velocity, base_angular_velocity=base_angular_velocity, joint_velocities=joint_velocities, base_linear_acceleration=base_linear_acceleration, base_angular_acceleration=base_angular_acceleration, joint_accelerations=joint_accelerations, link_forces=link_forces, standard_gravity=standard_gravity, ) W_g = jnp.atleast_2d(W_g).T W_v_WB = jnp.atleast_2d(W_v_WB).T W_v̇_WB = jnp.atleast_2d(W_v̇_WB).T # Get the 6D spatial inertia matrices of all links. M = js.model.link_spatial_inertia_matrices(model=model) # Get the parent array λ(i). # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array # Compute the base transform. W_H_B = jaxlie.SE3.from_rotation_and_translation( rotation=jaxlie.SO3(wxyz=W_Q_B), translation=W_p_B, ) # Compute 6D transforms of the base velocity. W_X_B = W_H_B.adjoint() B_X_W = W_H_B.inverse().adjoint() # Extract the parent-to-child adjoints of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. # Ensure cached transforms are JAX arrays so they work with traced indices. i_X_λi = jnp.asarray(joint_transforms) # Extract the joint motion subspaces. S = model.kin_dyn_parameters.motion_subspaces # Allocate buffers. v = jnp.zeros(shape=(model.number_of_links(), 6, 1)) a = jnp.zeros(shape=(model.number_of_links(), 6, 1)) f = jnp.zeros(shape=(model.number_of_links(), 6, 1)) # Allocate the buffer of transforms link -> base. i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6)) i_X_0 = i_X_0.at[0].set(jnp.eye(6)) # Initialize the acceleration of the base link. a_0 = -B_X_W @ W_g a = a.at[0].set(a_0) if model.floating_base(): # Base velocity v₀ in body-fixed representation. 
v_0 = B_X_W @ W_v_WB v = v.at[0].set(v_0) # Base acceleration a₀ in body-fixed representation w/o gravity. a_0 = B_X_W @ (W_v̇_WB - W_g) a = a.at[0].set(a_0) # Force applied to the base link that produce the base acceleration w/o gravity. f_0 = ( M[0] @ a[0] + Cross.vx_star(v[0]) @ M[0] @ v[0] - W_X_B.T @ jnp.vstack(W_f[0]) ) f = f.at[0].set(f_0) # ====== # Pass 1 # ====== ForwardPassCarry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix] forward_pass_carry: ForwardPassCarry = (v, a, i_X_0, f) def forward_pass( carry: ForwardPassCarry, i: jtp.Int ) -> tuple[ForwardPassCarry, None]: ii = i - 1 v, a, i_X_0, f = carry # Project the joint velocity into its motion subspace. vJ = S[i] * ṡ[ii] # Propagate the link velocity. v_i = i_X_λi[i] @ v[λ[i]] + vJ v = v.at[i].set(v_i) # Propagate the link acceleration. a_i = i_X_λi[i] @ a[λ[i]] + S[i] * s̈[ii] + Cross.vx(v[i]) @ vJ a = a.at[i].set(a_i) # Compute the link-to-base transform. i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]] i_X_0 = i_X_0.at[i].set(i_X_0_i) # Compute link-to-world transform for the 6D force. i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T # Compute the force acting on the link. f_i = ( M[i] @ a[i] + Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(W_f[i]) ) f = f.at[i].set(f_i) return (v, a, i_X_0, f), None (v, a, i_X_0, f), _ = ( jax.lax.scan( f=forward_pass, init=forward_pass_carry, xs=jnp.arange(start=1, stop=model.number_of_links()), ) if model.number_of_links() > 1 else [(v, a, i_X_0, f), None] ) # ====== # Pass 2 # ====== τ = jnp.zeros_like(s) BackwardPassCarry = tuple[jtp.Vector, jtp.Matrix] backward_pass_carry: BackwardPassCarry = (τ, f) def backward_pass( carry: BackwardPassCarry, i: jtp.Int ) -> tuple[BackwardPassCarry, None]: ii = i - 1 τ, f = carry # Project the 6D force to the DoF of the joint. τ_i = S[i].T @ f[i] τ = τ.at[ii].set(τ_i.squeeze()) # Propagate the force to the parent link. 
def update_f(f: jtp.Matrix) -> jtp.Matrix: f_λi = f[λ[i]] + i_X_λi[i].T @ f[i] f = f.at[λ[i]].set(f_λi) return f f = jax.lax.cond( pred=jnp.logical_or(λ[i] != 0, model.floating_base()), true_fun=update_f, false_fun=lambda f: f, operand=f, ) return (τ, f), None (τ, f), _ = ( jax.lax.scan( f=backward_pass, init=backward_pass_carry, xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), ) if model.number_of_links() > 1 else [(τ, f), None] ) # ============== # Adjust outputs # ============== # Express the base 6D force in the world frame. W_f0 = B_X_W.T @ f[0] return W_f0.squeeze(), jnp.atleast_1d(τ.squeeze()) ================================================ FILE: src/jaxsim/rbda/utils.py ================================================ import jax.numpy as jnp import jaxsim.api as js import jaxsim.typing as jtp from jaxsim import exceptions from jaxsim.math import STANDARD_GRAVITY def process_inputs( model: js.model.JaxSimModel, *, base_position: jtp.VectorLike | None = None, base_quaternion: jtp.VectorLike | None = None, joint_positions: jtp.VectorLike | None = None, base_linear_velocity: jtp.VectorLike | None = None, base_angular_velocity: jtp.VectorLike | None = None, joint_velocities: jtp.VectorLike | None = None, base_linear_acceleration: jtp.VectorLike | None = None, base_angular_acceleration: jtp.VectorLike | None = None, joint_accelerations: jtp.VectorLike | None = None, joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, standard_gravity: jtp.ScalarLike | None = None, ) -> tuple[ jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector, jtp.Matrix, jtp.Vector, ]: """ Adjust the inputs to rigid-body dynamics algorithms. Args: model: The model to consider. base_position: The position of the base link. base_quaternion: The quaternion of the base link. joint_positions: The positions of the joints. base_linear_velocity: The linear velocity of the base link. 
base_angular_velocity: The angular velocity of the base link. joint_velocities: The velocities of the joints. base_linear_acceleration: The linear acceleration of the base link. base_angular_acceleration: The angular acceleration of the base link. joint_accelerations: The accelerations of the joints. joint_forces: The forces applied to the joints. link_forces: The forces applied to the links. standard_gravity: The standard gravity constant. Returns: The adjusted inputs. """ dofs = model.dofs() nl = model.number_of_links() # Floating-base position. W_p_B = base_position W_Q_B = base_quaternion s = joint_positions # Floating-base velocity in inertial-fixed representation. W_vl_WB = base_linear_velocity W_ω_WB = base_angular_velocity ṡ = joint_velocities # Floating-base acceleration in inertial-fixed representation. W_v̇l_WB = base_linear_acceleration W_ω̇_WB = base_angular_acceleration s̈ = joint_accelerations # System dynamics inputs. f = link_forces τ = joint_forces # Fill missing data and adjust dimensions. 
s = jnp.atleast_1d(s.squeeze()) if s is not None else jnp.zeros(dofs) ṡ = jnp.atleast_1d(ṡ.squeeze()) if ṡ is not None else jnp.zeros(dofs) s̈ = jnp.atleast_1d(s̈.squeeze()) if s̈ is not None else jnp.zeros(dofs) τ = jnp.atleast_1d(τ.squeeze()) if τ is not None else jnp.zeros(dofs) W_vl_WB = jnp.atleast_1d(W_vl_WB.squeeze()) if W_vl_WB is not None else jnp.zeros(3) W_v̇l_WB = ( jnp.atleast_1d(W_v̇l_WB.squeeze()) if W_v̇l_WB is not None else jnp.zeros(3) ) W_p_B = jnp.atleast_1d(W_p_B.squeeze()) if W_p_B is not None else jnp.zeros(3) W_ω_WB = jnp.atleast_1d(W_ω_WB.squeeze()) if W_ω_WB is not None else jnp.zeros(3) W_ω̇_WB = jnp.atleast_1d(W_ω̇_WB.squeeze()) if W_ω̇_WB is not None else jnp.zeros(3) f = jnp.atleast_2d(f.squeeze()) if f is not None else jnp.zeros(shape=(nl, 6)) W_Q_B = ( jnp.atleast_1d(W_Q_B.squeeze()) if W_Q_B is not None else jnp.array([1.0, 0, 0, 0]) ) standard_gravity = ( jnp.array(standard_gravity).squeeze() if standard_gravity is not None else STANDARD_GRAVITY ) if s.shape != (dofs,): raise ValueError(s.shape, dofs) if ṡ.shape != (dofs,): raise ValueError(ṡ.shape, dofs) if s̈.shape != (dofs,): raise ValueError(s̈.shape, dofs) if τ.shape != (dofs,): raise ValueError(τ.shape, dofs) if W_p_B.shape != (3,): raise ValueError(W_p_B.shape, (3,)) if W_vl_WB.shape != (3,): raise ValueError(W_vl_WB.shape, (3,)) if W_ω_WB.shape != (3,): raise ValueError(W_ω_WB.shape, (3,)) if W_v̇l_WB.shape != (3,): raise ValueError(W_v̇l_WB.shape, (3,)) if W_ω̇_WB.shape != (3,): raise ValueError(W_ω̇_WB.shape, (3,)) if f.shape != (nl, 6): raise ValueError(f.shape, (nl, 6)) if W_Q_B.shape != (4,): raise ValueError(W_Q_B.shape, (4,)) # Check that the quaternion does not contain NaN values. exceptions.raise_value_error_if( condition=jnp.isnan(W_Q_B).any(), msg="A RBDA received a quaternion that contains NaN values.", ) # Check that the quaternion is unary since our RBDAs make this assumption in order # to prevent introducing additional normalizations that would affect AD. 
exceptions.raise_value_error_if( condition=~jnp.allclose(W_Q_B.dot(W_Q_B), 1.0), msg="A RBDA received a quaternion that is not normalized.", ) # Pack the 6D base velocity and acceleration. W_v_WB = jnp.hstack([W_vl_WB, W_ω_WB]) W_v̇_WB = jnp.hstack([W_v̇l_WB, W_ω̇_WB]) # Create the 6D gravity acceleration. W_g = jnp.array([0, 0, standard_gravity, 0, 0, 0]) return ( W_p_B.astype(float), W_Q_B.astype(float), s.astype(float), W_v_WB.astype(float), ṡ.astype(float), W_v̇_WB.astype(float), s̈.astype(float), τ.astype(float), f.astype(float), W_g.astype(float), ) ================================================ FILE: src/jaxsim/terrain/__init__.py ================================================ from . import terrain from .terrain import FlatTerrain, PlaneTerrain, Terrain ================================================ FILE: src/jaxsim/terrain/terrain.py ================================================ from __future__ import annotations import abc import dataclasses import jax.numpy as jnp import jax_dataclasses import numpy as np import jaxsim.math import jaxsim.typing as jtp from jaxsim import exceptions class Terrain(abc.ABC): """ Base class for terrain models. Attributes: delta: The delta value used for numerical differentiation. """ delta = 0.010 @abc.abstractmethod def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float: """ Compute the height of the terrain at a specific (x, y) location. Args: x: The x-coordinate of the location. y: The y-coordinate of the location. Returns: The height of the terrain at the specified location. """ pass def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector: """ Compute the normal vector of the terrain at a specific (x, y) location. Args: x: The x-coordinate of the location. y: The y-coordinate of the location. Returns: The normal vector of the terrain surface at the specified location. 
""" # https://stackoverflow.com/a/5282364 h_xp = self.height(x=x + self.delta, y=y) h_xm = self.height(x=x - self.delta, y=y) h_yp = self.height(x=x, y=y + self.delta) h_ym = self.height(x=x, y=y - self.delta) n = jnp.array( [(h_xm - h_xp) / (2 * self.delta), (h_ym - h_yp) / (2 * self.delta), 1.0] ) return n / jaxsim.math.safe_norm(n, axis=-1) @jax_dataclasses.pytree_dataclass class FlatTerrain(Terrain): """ Represents a terrain model with a flat surface and a constant height. """ _height: float = dataclasses.field(default=0.0, kw_only=True) @staticmethod def build(height: jtp.FloatLike = 0.0) -> FlatTerrain: """ Create a FlatTerrain instance with a specified height. Args: height: The height of the flat terrain. Returns: FlatTerrain: A FlatTerrain instance. """ return FlatTerrain(_height=float(height)) def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float: """ Compute the height of the terrain at a specific (x, y) location. Args: x: The x-coordinate of the location. y: The y-coordinate of the location. Returns: The height of the terrain at the specified location. """ return jnp.array(self._height, dtype=float) def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector: """ Compute the normal vector of the terrain at a specific (x, y) location. Args: x: The x-coordinate of the location. y: The y-coordinate of the location. Returns: The normal vector of the terrain surface at the specified location. """ return jnp.array([0.0, 0.0, 1.0], dtype=float) def __hash__(self) -> int: return hash(self._height) def __eq__(self, other: FlatTerrain) -> bool: if not isinstance(other, FlatTerrain): return False return self._height == other._height @jax_dataclasses.pytree_dataclass class PlaneTerrain(FlatTerrain): """ Represents a terrain model with a flat surface defined by a normal vector. 
""" _normal: tuple[float, float, float] = jax_dataclasses.field( default=(0.0, 0.0, 1.0), kw_only=True ) @staticmethod def build(height: jtp.FloatLike = 0.0, *, normal: jtp.VectorLike) -> PlaneTerrain: """ Create a PlaneTerrain instance with a specified plane normal vector. Args: normal: The normal vector of the terrain plane. height: The height of the plane over the origin. Returns: PlaneTerrain: A PlaneTerrain instance. """ normal = jnp.array(normal, dtype=float) height = jnp.array(height, dtype=float) if normal.shape != (3,): msg = "Expected a 3D vector for the plane normal, got '{}'." raise ValueError(msg.format(normal.shape)) # Make sure that the plane normal is a unit vector. normal = normal / jnp.linalg.norm(normal) return PlaneTerrain( _height=height.item(), _normal=tuple(normal.tolist()), ) def normal( self, x: jtp.FloatLike | None = None, y: jtp.FloatLike | None = None ) -> jtp.Vector: """ Compute the normal vector of the terrain at a specific (x, y) location. Args: x: The x-coordinate of the location. y: The y-coordinate of the location. Returns: The normal vector of the terrain surface at the specified location. """ return jnp.array(self._normal, dtype=float) def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float: """ Compute the height of the terrain at a specific (x, y) location on a plane. Args: x: The x-coordinate of the location. y: The y-coordinate of the location. Returns: The height of the terrain at the specified location on the plane. """ # Equation of the plane: A x + B y + C z + D = 0 # Normal vector coordinates: (A, B, C) # The height over the origin: -D/C # Get the plane equation coefficients from the terrain normal. A, B, C = self._normal exceptions.raise_value_error_if( condition=jnp.allclose(C, 0.0), msg="The z component of the normal cannot be zero.", ) # Compute the final coefficient D considering the terrain height. D = -C * self._height # Invert the plane equation to get the height at the given (x, y) coordinates. 
return jnp.array(-(A * x + B * y + D) / C).astype(float) def __hash__(self) -> int: from jaxsim.utils.wrappers import HashedNumpyArray return hash( ( hash(self._height), HashedNumpyArray.hash_of_array( array=np.array(self._normal, dtype=float) ), ) ) def __eq__(self, other: PlaneTerrain) -> bool: if not isinstance(other, PlaneTerrain): return False if not ( np.allclose(self._height, other._height) and np.allclose( np.array(self._normal, dtype=float), np.array(other._normal, dtype=float), ) ): return False return True ================================================ FILE: src/jaxsim/typing.py ================================================ from collections.abc import Hashable from typing import Any, TypeVar import jax # ========= # JAX types # ========= Array = jax.Array Scalar = Array Vector = Array Matrix = Array Int = Scalar Bool = Scalar Float = Scalar PyTree: object = ( dict[Hashable, TypeVar("PyTree")] | list[TypeVar("PyTree")] | tuple[TypeVar("PyTree")] | jax.Array | Any | None ) # ======================= # Mixed JAX / NumPy types # ======================= ArrayLike = jax.typing.ArrayLike | tuple ScalarLike = int | float | Scalar | ArrayLike VectorLike = Vector | ArrayLike | tuple MatrixLike = Matrix | ArrayLike IntLike = int | Int | jax.typing.ArrayLike BoolLike = bool | Bool | jax.typing.ArrayLike FloatLike = float | Float | jax.typing.ArrayLike ================================================ FILE: src/jaxsim/utils/__init__.py ================================================ from jax_dataclasses._copy_and_mutate import _Mutability as Mutability from .jaxsim_dataclass import JaxsimDataclass from .tracing import not_tracing, tracing from .wrappers import HashedNumpyArray, HashlessObject ================================================ FILE: src/jaxsim/utils/jaxsim_dataclass.py ================================================ import abc import contextlib import dataclasses import functools from collections.abc import Callable, Iterator, Sequence from typing 
import Any, ClassVar import jax.flatten_util import jax_dataclasses import jaxsim.typing as jtp from . import Mutability try: from typing import Self except ImportError: from typing_extensions import Self @jax_dataclasses.pytree_dataclass class JaxsimDataclass(abc.ABC): """Class extending `jax_dataclasses.pytree_dataclass` instances with utilities.""" # This attribute is set by jax_dataclasses __mutability__: ClassVar[Mutability] = Mutability.FROZEN @contextlib.contextmanager def editable(self: Self, validate: bool = True) -> Iterator[Self]: """ Context manager to operate on a mutable copy of the object. Args: validate: Whether to validate the output PyTree upon exiting the context. Yields: A mutable copy of the object. Note: This context manager is useful to operate on an r/w copy of a PyTree making sure that the output object does not trigger JIT recompilations. """ mutability = ( Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION ) with self.copy().mutable_context(mutability=mutability) as obj: yield obj @contextlib.contextmanager def mutable_context( self: Self, mutability: Mutability = Mutability.MUTABLE, restore_after_exception: bool = True, ) -> Iterator[Self]: """ Context manager to temporarily change the mutability of the object. Args: mutability: The mutability to set. restore_after_exception: Whether to restore the original object in case of an exception occurring within the context. Yields: The object with the new mutability. Note: This context manager is useful to operate in place on a PyTree without the need to make a copy while optionally keeping active the checks on the PyTree structure, shapes, and dtypes. 
""" if restore_after_exception: self_copy = self.copy() original_mutability = self.mutability() original_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self) original_shapes = JaxsimDataclass.get_leaf_shapes(tree=self) original_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self) original_structure = jax.tree.structure(tree=self) def restore_self() -> None: self.set_mutability(mutability=Mutability.MUTABLE_NO_VALIDATION) for f in dataclasses.fields(self_copy): setattr(self, f.name, getattr(self_copy, f.name)) try: self.set_mutability(mutability=mutability) yield self if mutability is not Mutability.MUTABLE_NO_VALIDATION: new_structure = jax.tree.structure(tree=self) if original_structure != new_structure: msg = "Pytree structure has changed from {} to {}" raise ValueError(msg.format(original_structure, new_structure)) new_shapes = JaxsimDataclass.get_leaf_shapes(tree=self) if original_shapes != new_shapes: msg = "Leaves shapes have changed from {} to {}" raise ValueError(msg.format(original_shapes, new_shapes)) new_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self) if original_dtypes != new_dtypes: msg = "Leaves dtypes have changed from {} to {}" raise ValueError(msg.format(original_dtypes, new_dtypes)) new_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self) if original_weak_types != new_weak_types: msg = "Leaves weak types have changed from {} to {}" raise ValueError(msg.format(original_weak_types, new_weak_types)) except Exception as e: if restore_after_exception: restore_self() self.set_mutability(original_mutability) raise e finally: self.set_mutability(original_mutability) @staticmethod def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None]: """ Get the leaf shapes of a PyTree. Args: tree: The PyTree to consider. Returns: A tuple containing the leaf shapes of the PyTree or `None` is the leaf is not a numpy-like array. 
""" return tuple( map( lambda leaf: getattr(leaf, "shape", None), jax.tree.leaves(tree), ) ) @staticmethod def get_leaf_dtypes(tree: jtp.PyTree) -> tuple: """ Get the leaf dtypes of a PyTree. Args: tree: The PyTree to consider. Returns: A tuple containing the leaf dtypes of the PyTree or `None` is the leaf is not a numpy-like array. """ return tuple( map( lambda leaf: getattr(leaf, "dtype", None), jax.tree.leaves(tree), ) ) @staticmethod def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool, ...]: """ Get the leaf weak types of a PyTree. Args: tree: The PyTree to consider. Returns: A tuple marking whether the leaf contains a JAX array with weak type. """ return tuple( map( lambda leaf: getattr(leaf, "weak_type", None), jax.tree.leaves(tree), ) ) @staticmethod def check_compatibility(*trees: Sequence[Any]) -> None: """ Check whether the PyTrees are compatible in structure, shape, and dtype. Args: *trees: The PyTrees to compare. Raises: ValueError: If the PyTrees have incompatible structures, shapes, or dtypes. 
""" target_structure = jax.tree.structure(trees[0]) compatible_structure = functools.reduce( lambda compatible, tree: compatible and jax.tree.structure(tree) == target_structure, trees[1:], True, ) if not compatible_structure: raise ValueError( f"Pytrees have incompatible structures.\n" f"Original: {', '.join(map(str, [jax.tree.structure(tree) for tree in trees[1:]]))}\n" f"Target: {target_structure}" ) target_shapes = JaxsimDataclass.get_leaf_shapes(trees[0]) compatible_shapes = functools.reduce( lambda compatible, tree: compatible and JaxsimDataclass.get_leaf_shapes(tree) == target_shapes, trees[1:], True, ) if not compatible_shapes: raise ValueError("Pytrees have incompatible shapes.") target_dtypes = JaxsimDataclass.get_leaf_dtypes(trees[0]) compatible_dtypes = functools.reduce( lambda compatible, tree: compatible and JaxsimDataclass.get_leaf_dtypes(tree) == target_dtypes, trees[1:], True, ) if not compatible_dtypes: raise ValueError("Pytrees have incompatible dtypes.") def is_mutable(self, validate: bool = False) -> bool: """ Check whether the object is mutable. Args: validate: Additionally checks if the object also has validation enabled. Returns: True if the object is mutable, False otherwise. """ return ( self.__mutability__ is Mutability.MUTABLE if validate else self.__mutability__ is Mutability.MUTABLE_NO_VALIDATION ) def mutability(self) -> Mutability: """ Get the mutability type of the object. Returns: The mutability type of the object. """ return self.__mutability__ def set_mutability(self, mutability: Mutability) -> None: """ Set the mutability of the object in-place. Args: mutability: The desired mutability type. """ jax_dataclasses._copy_and_mutate._mark_mutable( self, mutable=mutability, visited=set() ) def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self: """ Return a mutable reference of the object. Args: mutable: Whether to make the object mutable. validate: Whether to enable validation on the object. 
Returns: A mutable reference of the object. """ if mutable: mutability = ( Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION ) else: mutability = Mutability.FROZEN self.set_mutability(mutability=mutability) return self def copy(self: Self) -> Self: """ Return a copy of the object. Returns: A copy of the object. """ # Make a copy calling tree_map. obj = jax.tree.map(lambda leaf: leaf, self) # Make sure that the copied object and all the copied leaves have the same # mutability of the original object. obj.set_mutability(mutability=self.mutability()) return obj def replace(self: Self, validate: bool = True, **kwargs) -> Self: """ Return a new object replacing in-place the specified fields with new values. Args: validate: Whether to validate that the new fields do not alter the PyTree. **kwargs: The fields to replace. Returns: A reference of the object with the specified fields replaced. """ # Use the dataclasses replace method. obj = dataclasses.replace(self, **kwargs) if validate: JaxsimDataclass.check_compatibility(self, obj) # Make sure that all the new leaves have the same mutability of the object. obj.set_mutability(mutability=self.mutability()) return obj def flatten(self) -> jtp.Vector: """ Flatten the object into a 1D vector. Returns: A 1D vector containing the flattened object. """ return self.flatten_fn()(self) @classmethod def flatten_fn(cls: type[Self]) -> Callable[[Self], jtp.Vector]: """ Return a function to flatten the object into a 1D vector. Returns: A function to flatten the object into a 1D vector. """ return lambda pytree: jax.flatten_util.ravel_pytree(pytree)[0] def unflatten_fn(self: Self) -> Callable[[jtp.Vector], Self]: """ Return a function to unflatten a 1D vector into the object. Returns: A function to unflatten a 1D vector into the object. Notes: Due to JAX internals, the function to unflatten a PyTree needs to be created from an existing instance of the PyTree. 
""" return jax.flatten_util.ravel_pytree(self)[1] ================================================ FILE: src/jaxsim/utils/tracing.py ================================================ from typing import Any import jax._src.core import jax.flatten_util import jax.interpreters.partial_eval def tracing(var: Any) -> bool | jax.Array: """Return True if the variable is being traced by JAX, False otherwise.""" return isinstance( var, jax._src.core.Tracer | jax.interpreters.partial_eval.DynamicJaxprTracer ) def not_tracing(var: Any) -> bool | jax.Array: """Return True if the variable is not being traced by JAX, False otherwise.""" return True if tracing(var) is False else False ================================================ FILE: src/jaxsim/utils/wrappers.py ================================================ from __future__ import annotations import dataclasses from collections.abc import Callable from typing import Generic, TypeVar import jax import jax_dataclasses import numpy as np import numpy.typing as npt T = TypeVar("T") @dataclasses.dataclass class HashlessObject(Generic[T]): """ A class that wraps an object and makes it hashless. This is useful for creating particular JAX pytrees. For example, to create a pytree with a static leaf that is ignored by JAX when it compares two instances to trigger a JIT recompilation. """ obj: T def get(self: HashlessObject[T]) -> T: """ Get the wrapped object. """ return self.obj def __hash__(self) -> int: return 0 def __eq__(self, other: HashlessObject[T]) -> bool: if not isinstance(other, HashlessObject) and isinstance( other.get(), type(self.get()) ): return False return hash(self) == hash(other) @dataclasses.dataclass class CustomHashedObject(Generic[T]): """ A class that wraps an object and computes its hash with a custom hash function. """ obj: T hash_function: Callable[[T], int] = hash def get(self: CustomHashedObject[T]) -> T: """ Get the wrapped object. 
""" return self.obj def __hash__(self) -> int: return self.hash_function(self.obj) def __eq__(self, other: CustomHashedObject[T]) -> bool: if not isinstance(other, CustomHashedObject) and isinstance( other.get(), type(self.get()) ): return False return hash(self) == hash(other) @jax_dataclasses.pytree_dataclass class HashedNumpyArray: """ A class that wraps a numpy array and makes it hashable. This is useful for creating particular JAX pytrees. For example, to create a pytree with a plain NumPy or JAX NumPy array as static leaf. Note: Calculating with the wrapper class the hash of a very large array can be very expensive. If the array is large and only the equality operator is needed, set `large_array=True` to use a faster comparison method. """ array: jax.Array | npt.NDArray precision: float | None = dataclasses.field( default=1e-9, repr=False, compare=False, hash=False ) large_array: jax_dataclasses.Static[bool] = dataclasses.field( default=False, repr=False, compare=False, hash=False ) def get(self) -> jax.Array | npt.NDArray: """ Get the wrapped array. """ return self.array def __hash__(self) -> int: return HashedNumpyArray.hash_of_array( array=self.array, precision=self.precision ) def __eq__(self, other: HashedNumpyArray) -> bool: if not isinstance(other, HashedNumpyArray): return False if self.large_array: return np.allclose( self.array, other.array, **(dict(atol=self.precision) if self.precision is not None else {}), ) return hash(self) == hash(other) @staticmethod def hash_of_array( array: jax.Array | npt.NDArray, precision: float | None = 1e-9 ) -> int: """ Calculate the hash of a NumPy array. Args: array: The array to hash. precision: Optionally limit the precision over which the hash is computed. Returns: The hash of the array. 
""" array = np.array(array).flatten() array = np.where(array == np.nan, hash(np.nan), array) array = np.where(array == np.inf, hash(np.inf), array) array = np.where(array == -np.inf, hash(-np.inf), array) if precision is not None: integer1 = (array * precision).astype(int) integer2 = (array - integer1 / precision).astype(int) decimal_array = ((array - integer1 * 1e9 - integer2) / precision).astype( int ) array = np.hstack([integer1, integer2, decimal_array]).astype(int) return hash(tuple(array.tolist())) ================================================ FILE: tests/__init__.py ================================================ ================================================ FILE: tests/assets/4_bar_opened.urdf ================================================ ================================================ FILE: tests/assets/cube.stl ================================================ solid model facet normal 0.0 0.0 -1.0 outer loop vertex 20.0 0.0 0.0 vertex 0.0 -20.0 0.0 vertex 0.0 0.0 0.0 endloop endfacet facet normal 0.0 0.0 -1.0 outer loop vertex 0.0 -20.0 0.0 vertex 20.0 0.0 0.0 vertex 20.0 -20.0 0.0 endloop endfacet facet normal -0.0 -1.0 -0.0 outer loop vertex 20.0 -20.0 20.0 vertex 0.0 -20.0 0.0 vertex 20.0 -20.0 0.0 endloop endfacet facet normal -0.0 -1.0 -0.0 outer loop vertex 0.0 -20.0 0.0 vertex 20.0 -20.0 20.0 vertex 0.0 -20.0 20.0 endloop endfacet facet normal 1.0 0.0 0.0 outer loop vertex 20.0 0.0 0.0 vertex 20.0 -20.0 20.0 vertex 20.0 -20.0 0.0 endloop endfacet facet normal 1.0 0.0 0.0 outer loop vertex 20.0 -20.0 20.0 vertex 20.0 0.0 0.0 vertex 20.0 0.0 20.0 endloop endfacet facet normal -0.0 -0.0 1.0 outer loop vertex 20.0 -20.0 20.0 vertex 0.0 0.0 20.0 vertex 0.0 -20.0 20.0 endloop endfacet facet normal -0.0 -0.0 1.0 outer loop vertex 0.0 0.0 20.0 vertex 20.0 -20.0 20.0 vertex 20.0 0.0 20.0 endloop endfacet facet normal -1.0 0.0 0.0 outer loop vertex 0.0 0.0 20.0 vertex 0.0 -20.0 0.0 vertex 0.0 -20.0 20.0 endloop endfacet facet normal -1.0 0.0 0.0 
outer loop vertex 0.0 -20.0 0.0 vertex 0.0 0.0 20.0 vertex 0.0 0.0 0.0 endloop endfacet facet normal -0.0 1.0 0.0 outer loop vertex 0.0 0.0 20.0 vertex 20.0 0.0 0.0 vertex 0.0 0.0 0.0 endloop endfacet facet normal -0.0 1.0 0.0 outer loop vertex 20.0 0.0 0.0 vertex 0.0 0.0 20.0 vertex 20.0 0.0 20.0 endloop endfacet endsolid model ================================================ FILE: tests/assets/double_pendulum.sdf ================================================ world base_link 1 0 0 -5 5 100 100 0.0 0 0.0 0 0 0 0 0 0 100 1 0 0 1 0 1 0 0 1 0 0 0 0.20 0.20 2.15 0 0 1 0 0 0 0.20 0.20 2.15 0.20 0 2 -3.1415 0 0 base_link right_link 1 0 0 -100 100 100 100 1.0 0 0.0 0 0 0 0 0 0 0 0 0 0.5 0 0 0 1 1.0 0 0 1.0 0 1.0 0 0 0.5 0 0 0 0.20 0.20 1.0 -0.20 0 2 -3.1415 0 0 base_link left_link 1 0 0 -100 100 100 100 1.0 0 0.0 0 0 0 0 0 0 0 0.0 0 0.5 0 0 0 1 1.0 0 0 1.0 0 1.0 0.0 0 0.5 0 0 0 0.20 0.20 1.0 0.20 0 1 0 0 0 -0.20 0 1 0 0 0 -0.2 0 1 3.14 0 0 0.2 0 1 3.14 0 0 ================================================ FILE: tests/assets/mixed_shapes_robot.urdf ================================================ ================================================ FILE: tests/assets/test_cube.urdf ================================================ ================================================ FILE: tests/conftest.py ================================================ import os os.environ["JAXSIM_ENABLE_EXCEPTIONS"] = "1" import pathlib import subprocess import jax import numpy as np import pytest import rod import rod.urdf.exporter import jaxsim import jaxsim.api as js from jaxsim.api.model import IntegratorType def pytest_addoption(parser): parser.addoption( "--gpu-only", action="store_true", default=False, help="Run tests only if GPU is available and utilized", ) parser.addoption( "--batch-size", action="store", default="None", help="Batch size for vectorized benchmarks (only applies to benchmark tests)", ) def pytest_generate_tests(metafunc): if ( "batch_size" in metafunc.fixturenames and 
(batch_size := metafunc.config.getoption("--batch-size")) != "None" ): metafunc.parametrize("batch_size", [1, int(batch_size)]) def check_gpu_usage(): # Set environment variable to prioritize GPU. os.environ["JAX_PLATFORM_NAME"] = "gpu" # Run a simple JAX operation x = jax.device_put(jax.numpy.ones((512, 512))) y = jax.device_put(jax.numpy.ones((512, 512))) _ = jax.numpy.dot(x, y).block_until_ready() # Check GPU memory usage with nvidia-smi. result = subprocess.run( ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader"], capture_output=True, text=True, ) if result.returncode != 0: pytest.exit( "Failed to query GPU usage. Ensure nvidia-smi is installed and accessible." ) gpu_memory_usage = [ int(line.strip().split()[0]) for line in result.stdout.splitlines() ] if all(usage == 0 for usage in gpu_memory_usage): pytest.exit( "GPU is available but not utilized during computations. Check your JAX installation." ) def pytest_configure(config) -> None: """Pytest configuration hook.""" # This is a global variable that is updated by the `prng_key` fixture. pytest.prng_key = jax.random.PRNGKey( seed=int(os.environ.get("JAXSIM_TEST_SEED", 0)) ) # Check if GPU is available and utilized. if config.getoption("--gpu-only"): devices = jax.devices() if not any(device.platform == "gpu" for device in devices): pytest.exit("No GPU devices found. Check your JAX installation.") # Ensure GPU is being used during computation check_gpu_usage() def load_model_from_file(file_path: pathlib.Path, is_urdf=False) -> rod.Sdf: """ Load an SDF or URDF model from a file. Args: file_path: The path to the model file. is_urdf: Whether the file is in URDF or SDF format. Returns: The corresponding rod model. """ return rod.Sdf.load(file_path, is_urdf=is_urdf) # ================ # Generic fixtures # ================ @pytest.fixture(scope="function") def prng_key() -> jax.Array: """ Fixture to generate a new PRNG key for each test function. Returns: The new PRNG key passed to the test. 
Note: This fixture operates on a global variable initialized in the `pytest_configure` hook. """ pytest.prng_key, subkey = jax.random.split(pytest.prng_key, num=2) return subkey @pytest.fixture( scope="function", params=[ pytest.param(jaxsim.VelRepr.Inertial, id="inertial"), pytest.param(jaxsim.VelRepr.Body, id="body"), pytest.param(jaxsim.VelRepr.Mixed, id="mixed"), ], ) def velocity_representation(request) -> jaxsim.VelRepr: """ Parametrized fixture providing all supported velocity representations. Returns: A velocity representation. """ return request.param @pytest.fixture( scope="function", params=[ pytest.param(IntegratorType.SemiImplicitEuler, id="semi_implicit_euler"), pytest.param(IntegratorType.RungeKutta4, id="runge_kutta_4"), pytest.param(IntegratorType.RungeKutta4Fast, id="runge_kutta_4_fast"), ], ) def integrator(request) -> str: """ Fixture providing the integrator to use in the simulation. Returns: The integrator to use in the simulation. """ return request.param @pytest.fixture(scope="session") def batch_size(request) -> int: """ Fixture providing the batch size for vectorized benchmarks. Returns: The batch size for vectorized benchmarks. """ return 1 # ================================ # Fixtures providing JaxSim models # ================================ # All the fixtures in this section must have "session" scope. # In this way, the models are generated only once and shared among all the tests. # This is not a fixture. def build_jaxsim_model( model_description: str | pathlib.Path | rod.Model, ) -> js.model.JaxSimModel: """ Build a JaxSim model from a model description. Args: model_description: A model description provided by any fixture provider. Returns: A JaxSim model built from the provided description. """ # Build the JaxSim model. 
model = js.model.JaxSimModel.build_from_model_description( model_description=model_description, ) return model @pytest.fixture(scope="session") def jaxsim_model_box() -> js.model.JaxSimModel: """ Fixture providing the JaxSim model of a box. Returns: The JaxSim model of a box. """ import rod.builder.primitives import rod.urdf.exporter # Create on-the-fly a ROD model of a box. rod_model = ( rod.builder.primitives.BoxBuilder(x=0.3, y=0.2, z=0.1, mass=1.0, name="box") .build_model() .add_link(name="box_link") .add_inertial() .add_visual() .add_collision() .build() ) rod_model.add_frame( rod.Frame( name="box_frame", attached_to="box_link", pose=rod.Pose(relative_to="box_link", pose=[1, 1, 1, 0.5, 0.4, 0.3]), ) ) # Export the URDF string. urdf_string = rod.urdf.exporter.UrdfExporter( pretty=True, gazebo_preserve_fixed_joints=True ).to_urdf_string(sdf=rod_model) return build_jaxsim_model(model_description=urdf_string) @pytest.fixture(scope="session") def jaxsim_model_sphere() -> js.model.JaxSimModel: """ Fixture providing the JaxSim model of a sphere. Returns: The JaxSim model of a sphere. """ import rod.builder.primitives import rod.urdf.exporter # Create on-the-fly a ROD model of a sphere. rod_model = ( rod.builder.primitives.SphereBuilder(radius=0.1, mass=1.0, name="sphere") .build_model() .add_link() .add_inertial() .add_visual() .add_collision() .build() ) # Export the URDF string. urdf_string = rod.urdf.exporter.UrdfExporter(pretty=True).to_urdf_string( sdf=rod_model ) return build_jaxsim_model(model_description=urdf_string) @pytest.fixture(scope="session") def ergocub_model_description_path() -> pathlib.Path: """ Fixture providing the path to the URDF model description of the ErgoCub robot. Returns: The path to the URDF model description of the ErgoCub robot. 
""" try: os.environ["ROBOT_DESCRIPTION_COMMIT"] = "v0.7.7" import robot_descriptions.ergocub_description finally: _ = os.environ.pop("ROBOT_DESCRIPTION_COMMIT", None) model_urdf_path = pathlib.Path( robot_descriptions.ergocub_description.URDF_PATH.replace( "ergoCubSN002", "ergoCubSN001" ) ) return model_urdf_path @pytest.fixture(scope="session") def jaxsim_model_ergocub( ergocub_model_description_path: pathlib.Path, ) -> js.model.JaxSimModel: """ Fixture providing the JaxSim model of the ErgoCub robot. Returns: The JaxSim model of the ErgoCub robot. """ return build_jaxsim_model(model_description=ergocub_model_description_path) @pytest.fixture(scope="session") def jaxsim_model_ergocub_reduced(jaxsim_model_ergocub) -> js.model.JaxSimModel: """ Fixture providing the JaxSim model of the ErgoCub robot with only locomotion joints. Returns: The JaxSim model of the ErgoCub robot with only locomotion joints. """ model_full = jaxsim_model_ergocub # Get the names of the joints to keep. reduced_joints = tuple( j for j in model_full.joint_names() if "camera" not in j # Remove head and hands. and "neck" not in j and "wrist" not in j and "thumb" not in j and "index" not in j and "middle" not in j and "ring" not in j and "pinkie" not in j # Remove upper body. and "torso" not in j and "elbow" not in j and "shoulder" not in j ) model = js.model.reduce(model=model_full, considered_joints=reduced_joints) return model @pytest.fixture(scope="session") def jaxsim_model_ur10() -> js.model.JaxSimModel: """ Fixture providing the JaxSim model of the UR10 robot. Returns: The JaxSim model of the UR10 robot. """ import robot_descriptions.ur10_description model_urdf_path = pathlib.Path(robot_descriptions.ur10_description.URDF_PATH) return build_jaxsim_model(model_description=model_urdf_path) @pytest.fixture(scope="session") def jaxsim_model_single_pendulum() -> js.model.JaxSimModel: """ Fixture providing the JaxSim model of a single pendulum. Returns: The JaxSim model of a single pendulum. 
""" import rod.builder.primitives base_height = 2.15 upper_height = 1.0 # =================== # Create the builders # =================== base_builder = rod.builder.primitives.BoxBuilder( name="base", mass=1.0, x=0.15, y=0.15, z=base_height, ) upper_builder = rod.builder.primitives.BoxBuilder( name="upper", mass=0.5, x=0.15, y=0.15, z=upper_height, ) # ================= # Create the joints # ================= fixed = rod.Joint( name="fixed_joint", type="fixed", parent="world", child=base_builder.name, ) pivot = rod.Joint( name="upper_joint", type="continuous", parent=base_builder.name, child=upper_builder.name, axis=rod.Axis( xyz=rod.Xyz([1, 0, 0]), ), ) # ================ # Create the links # ================ base = ( base_builder.build_link( name=base_builder.name, pose=rod.builder.primitives.PrimitiveBuilder.build_pose( pos=np.array([0, 0, base_height / 2]) ), ) .add_inertial() .add_visual() .add_collision() .build() ) upper_pose = rod.builder.primitives.PrimitiveBuilder.build_pose( pos=np.array([0, 0, upper_height / 2]) ) upper = ( upper_builder.build_link( name=upper_builder.name, pose=rod.builder.primitives.PrimitiveBuilder.build_pose( relative_to=base.name, pos=np.array([0, 0, upper_height]) ), ) .add_inertial(pose=upper_pose) .add_visual(pose=upper_pose) .add_collision(pose=upper_pose) .build() ) rod_model = rod.Sdf( version="1.10", model=rod.Model( name="single_pendulum", link=[base, upper], joint=[fixed, pivot], ), ) rod_model.model.resolve_frames() urdf_string = rod.urdf.exporter.UrdfExporter(pretty=True).to_urdf_string( sdf=rod_model.models()[0] ) model = build_jaxsim_model(model_description=urdf_string) return model @pytest.fixture(scope="session") def jaxsim_model_garpez() -> js.model.JaxSimModel: """Fixture to create the original (unscaled) Garpez model.""" rod_model = create_scalable_garpez_model() urdf_string = rod.urdf.exporter.UrdfExporter(pretty=True).to_urdf_string( sdf=rod_model ) return build_jaxsim_model(model_description=urdf_string) 
@pytest.fixture(scope="session") def jaxsim_model_garpez_scaled(request) -> js.model.JaxSimModel: """Fixture to create the scaled version of the Garpez model.""" # Get the link scales from the request. link1_scale = request.param.get("link1_scale", 1.0) link2_scale = request.param.get("link2_scale", 1.0) link3_scale = request.param.get("link3_scale", 1.0) link4_scale = request.param.get("link4_scale", 1.0) rod_model = create_scalable_garpez_model( link1_scale=link1_scale, link2_scale=link2_scale, link3_scale=link3_scale, link4_scale=link4_scale, ) urdf_string = rod.urdf.exporter.UrdfExporter(pretty=True).to_urdf_string( sdf=rod_model ) return build_jaxsim_model(model_description=urdf_string) def create_scalable_garpez_model( link1_scale: float = 1.0, link2_scale: float = 1.0, link3_scale: float = 1.0, link4_scale: float = 1.0, ) -> rod.Model: """ Build a scalable rod model to test parameterization and scaling. Args: link1_scale: Scale factor for link 1. link2_scale: Scale factor for link 2. link3_scale: Scale factor for link 3. link4_scale: Scale factor for link 4. Returns: A rod model with the specified link scales. Note: The model is built assuming a constant link density, hence scaling the link will also have an impact on the link mass. 
""" import numpy as np from rod.builder import primitives # ======================== # Create the link builders # ======================== density = 1000.0 # Fixed density in kg/m^3 l1_x, l1_y, l1_z = 0.3 * link1_scale, 0.2, 0.2 l1_volume = l1_x * l1_y * l1_z l1_mass = density * l1_volume link1_builder = primitives.BoxBuilder( name="link1", mass=l1_mass, x=l1_x, y=l1_y, z=l1_z ) l2_radius = 0.1 * link2_scale l2_volume = 4 / 3 * np.pi * l2_radius**3 l2_mass = density * l2_volume link2_builder = primitives.SphereBuilder( name="link2", mass=l2_mass, radius=l2_radius ) l3_radius = 0.05 l3_length = 0.5 * link3_scale l3_volume = np.pi * l3_radius**2 * l3_length l3_mass = density * l3_volume link3_builder = primitives.CylinderBuilder( name="link3", mass=l3_mass, radius=l3_radius, length=l3_length ) l4_x, l4_y, l4_z = 0.3 * link4_scale, 0.2, 0.1 l4_volume = l4_x * l4_y * l4_z l4_mass = density * l4_volume link4_builder = primitives.BoxBuilder( name="link4", mass=l4_mass, x=l4_x, y=l4_y, z=l4_z ) # ================= # Create the joints # ================= link1_to_link2 = rod.Joint( name="link1_to_link2", type="revolute", parent=link1_builder.name, child=link2_builder.name, pose=primitives.PrimitiveBuilder.build_pose( relative_to=link1_builder.name, pos=np.array([link1_builder.x, link1_builder.y / 2, link1_builder.z / 2]), ), axis=rod.Axis(xyz=rod.Xyz(xyz=[0, 1, 0]), limit=rod.Limit()), ) link2_to_link3 = rod.Joint( name="link2_to_link3", type="revolute", parent=link2_builder.name, child=link3_builder.name, pose=primitives.PrimitiveBuilder.build_pose( relative_to=link2_builder.name, pos=np.array([link2_builder.radius, 0, -link2_builder.radius]), ), axis=rod.Axis(xyz=rod.Xyz(xyz=[0, 0, 1]), limit=rod.Limit()), ) link3_to_link4 = rod.Joint( name="link3_to_link4", type="revolute", parent=link3_builder.name, child=link4_builder.name, pose=primitives.PrimitiveBuilder.build_pose( relative_to=link3_builder.name, pos=np.array([-link3_builder.radius, 0, -link3_builder.length]), ), 
axis=rod.Axis(xyz=rod.Xyz(xyz=[1, 0, 0]), limit=rod.Limit()), ) # ================ # Create the links # ================ link1_elements_pose = primitives.PrimitiveBuilder.build_pose( pos=np.array([link1_builder.x, link1_builder.y, link1_builder.z]) / 2 ) link1 = ( link1_builder.build_link( name=link1_builder.name, pose=primitives.PrimitiveBuilder.build_pose(relative_to="__model__"), ) .add_inertial(pose=link1_elements_pose) .add_visual(pose=link1_elements_pose) .add_collision(pose=link1_elements_pose) .build() ) link2_elements_pose = primitives.PrimitiveBuilder.build_pose( pos=np.array([link2_builder.radius, 0, 0]) ) link2 = ( link2_builder.build_link( name=link2_builder.name, pose=primitives.PrimitiveBuilder.build_pose( relative_to=link1_to_link2.name ), ) .add_inertial(pose=link2_elements_pose) .add_visual(pose=link2_elements_pose) .add_collision(pose=link2_elements_pose) .build() ) link3_elements_pose = primitives.PrimitiveBuilder.build_pose( pos=np.array([0, 0, -link3_builder.length / 2]) ) link3 = ( link3_builder.build_link( name=link3_builder.name, pose=primitives.PrimitiveBuilder.build_pose( relative_to=link2_to_link3.name ), ) .add_inertial(pose=link3_elements_pose) .add_visual(pose=link3_elements_pose) .add_collision(pose=link3_elements_pose) .build() ) link4_elements_pose = primitives.PrimitiveBuilder.build_pose( # pos=np.array([0, 0, -link4_builder.z / 2]) pos=np.array([link4_builder.x / 2, 0, -link4_builder.z / 2]) ) link4 = ( link4_builder.build_link( name=link4_builder.name, pose=primitives.PrimitiveBuilder.build_pose( relative_to=link3_to_link4.name ), ) .add_inertial(pose=link4_elements_pose) .add_visual(pose=link4_elements_pose) .add_collision(pose=link4_elements_pose) .build() ) # =========== # Build model # =========== # Create model rod_model = rod.Model( name="model_demo", canonical_link=link1.name, link=[link1, link2, link3, link4], joint=[link1_to_link2, link2_to_link3, link3_to_link4], ) rod_model.switch_frame_convention( 
frame_convention=rod.FrameConvention.Urdf, explicit_frames=True, attach_frames_to_links=True, ) assert rod.Sdf(model=rod_model, version="1.10").serialize(validate=True) return rod_model def create_model_with_missing_collision() -> rod.Model: """ Build a rod model with a link that has a visual but no collision element. This model is used to test the export logic when collision elements are missing. Returns: A rod model with one link missing a collision element. """ import numpy as np from rod.builder import primitives density = 1000.0 # Fixed density in kg/m^3 # Create link1 with both visual and collision l1_x, l1_y, l1_z = 0.3, 0.2, 0.2 l1_volume = l1_x * l1_y * l1_z l1_mass = density * l1_volume link1_builder = primitives.BoxBuilder( name="link1", mass=l1_mass, x=l1_x, y=l1_y, z=l1_z ) # Create link2 with visual but WITHOUT collision l2_radius = 0.1 l2_volume = 4 / 3 * np.pi * l2_radius**3 l2_mass = density * l2_volume link2_builder = primitives.SphereBuilder( name="link2", mass=l2_mass, radius=l2_radius ) # Create joint link1_to_link2 = rod.Joint( name="link1_to_link2", type="revolute", parent=link1_builder.name, child=link2_builder.name, pose=primitives.PrimitiveBuilder.build_pose( relative_to=link1_builder.name, pos=np.array([link1_builder.x, link1_builder.y / 2, link1_builder.z / 2]), ), axis=rod.Axis(xyz=rod.Xyz(xyz=[0, 1, 0]), limit=rod.Limit()), ) # Build link1 with visual and collision link1_elements_pose = primitives.PrimitiveBuilder.build_pose( pos=np.array([link1_builder.x, link1_builder.y, link1_builder.z]) / 2 ) link1 = ( link1_builder.build_link( name=link1_builder.name, pose=primitives.PrimitiveBuilder.build_pose(relative_to="__model__"), ) .add_inertial(pose=link1_elements_pose) .add_visual(pose=link1_elements_pose) .add_collision(pose=link1_elements_pose) .build() ) # Build link2 with visual but NO collision link2_elements_pose = primitives.PrimitiveBuilder.build_pose( pos=np.array([link2_builder.radius, 0, 0]) ) link2 = ( 
link2_builder.build_link( name=link2_builder.name, pose=primitives.PrimitiveBuilder.build_pose( relative_to=link1_to_link2.name ), ) .add_inertial(pose=link2_elements_pose) .add_visual(pose=link2_elements_pose) # Note: NO .add_collision() call here .build() ) # Create model rod_model = rod.Model( name="model_missing_collision", canonical_link=link1.name, link=[link1, link2], joint=[link1_to_link2], ) rod_model.switch_frame_convention( frame_convention=rod.FrameConvention.Urdf, explicit_frames=True, attach_frames_to_links=True, ) assert rod.Sdf(model=rod_model, version="1.10").serialize(validate=True) return rod_model @pytest.fixture(scope="session") def jaxsim_model_missing_collision() -> js.model.JaxSimModel: """ Fixture to create a model with a link that has a visual but no collision element. This is used to test the export logic when collision elements are missing. """ rod_model = create_model_with_missing_collision() urdf_string = rod.urdf.exporter.UrdfExporter(pretty=True).to_urdf_string( sdf=rod_model ) return build_jaxsim_model(model_description=urdf_string) @pytest.fixture(scope="session") def jaxsim_model_double_pendulum() -> js.model.JaxSimModel: """ Fixture providing the JaxSim model of a double pendulum. Returns: The JaxSim model of a double pendulum. """ model_path = pathlib.Path(__file__).parent / "assets" / "double_pendulum.sdf" rod_model = load_model_from_file(model_path) model = build_jaxsim_model(model_description=rod_model) return model @pytest.fixture(scope="session") def jaxsim_model_cartpole() -> js.model.JaxSimModel: """ Fixture providing the JaxSim model of a cartpole. Returns: The JaxSim model of a cartpole. 
""" model_path = ( pathlib.Path(__file__).parent.parent / "examples" / "assets" / "cartpole.urdf" ) rod_model = load_model_from_file(model_path, is_urdf=True) model = build_jaxsim_model(model_description=rod_model) return model @pytest.fixture(scope="session") def jaxsim_model_4_bar_linkage() -> js.model.JaxSimModel: """ Fixture providing the JaxSim model of a 4-bar linkage (opened configuration). Returns: The JaxSim model of the 4-bar linkage. """ model_path = pathlib.Path(__file__).parent / "assets" / "4_bar_opened.urdf" rod_model = load_model_from_file(model_path, is_urdf=True) model = build_jaxsim_model(model_description=rod_model) return model # ============================ # Collections of JaxSim models # ============================ def get_jaxsim_model_fixture( model_name: str, request: pytest.FixtureRequest ) -> str | pathlib.Path: """ Get the fixture providing the JaxSim model of a robot. Args: model_name: The name of the model. request: The request object. Returns: The JaxSim model of the robot. """ match model_name: case "box": return request.getfixturevalue(jaxsim_model_box.__name__) case "sphere": return request.getfixturevalue(jaxsim_model_sphere.__name__) case "ergocub": return request.getfixturevalue(jaxsim_model_ergocub.__name__) case "ergocub_reduced": return request.getfixturevalue(jaxsim_model_ergocub_reduced.__name__) case "ur10": return request.getfixturevalue(jaxsim_model_ur10.__name__) case "single_pendulum": return request.getfixturevalue(jaxsim_model_single_pendulum.__name__) case "garpez": return request.getfixturevalue(jaxsim_model_garpez.__name__) case "garpez_scaled": return request.getfixturevalue(jaxsim_model_garpez_scaled.__name__) case _: raise ValueError(model_name) @pytest.fixture( scope="session", params=[ "box", "sphere", "ur10", "ergocub", "ergocub_reduced", ], ) def jaxsim_models_all(request) -> pathlib.Path | str: """ Fixture providing the JaxSim models of all supported robots. 
""" model_name: str = request.param return get_jaxsim_model_fixture(model_name=model_name, request=request) @pytest.fixture( scope="session", params=[ "box", "ur10", "ergocub_reduced", ], ) def jaxsim_models_types(request) -> pathlib.Path | str: """ Fixture providing JaxSim models of all types of supported robots. Note: At the moment, most of our tests use this fixture. It provides: - A robot with no joints. - A fixed-base robot. - A floating-base robot. """ model_name: str = request.param return get_jaxsim_model_fixture(model_name=model_name, request=request) @pytest.fixture( scope="session", params=[ "box", "sphere", ], ) def jaxsim_models_no_joints(request) -> pathlib.Path | str: """ Fixture providing JaxSim models of robots with no joints. """ model_name: str = request.param return get_jaxsim_model_fixture(model_name=model_name, request=request) @pytest.fixture( scope="session", params=[ "ergocub", "ergocub_reduced", ], ) def jaxsim_models_floating_base(request) -> pathlib.Path | str: """ Fixture providing JaxSim models of floating-base robots. """ model_name: str = request.param return get_jaxsim_model_fixture(model_name=model_name, request=request) @pytest.fixture( scope="session", params=[ "ur10", ], ) def jaxsim_models_fixed_base(request) -> pathlib.Path | str: """ Fixture providing JaxSim models of fixed-base robots. """ model_name: str = request.param return get_jaxsim_model_fixture(model_name=model_name, request=request) @pytest.fixture(scope="function") def set_jax_32bit(monkeypatch): """ Fixture that temporarily sets JAX precision to 32-bit for the duration of the test. """ del globals()["jaxsim"] del globals()["js"] # Temporarily disable x64 monkeypatch.setenv("JAX_ENABLE_X64", "0") @pytest.fixture(scope="function") def jaxsim_model_box_32bit(set_jax_32bit, request) -> js.model.JaxSimModel: """ Fixture providing the JaxSim model of a box with 32-bit precision. Returns: The JaxSim model of a box with 32-bit precision. 
""" return get_jaxsim_model_fixture(model_name="box", request=request) ================================================ FILE: tests/test_actuation.py ================================================ import jax.numpy as jnp from numpy.testing import assert_array_less import jaxsim.api as js import jaxsim.rbda from jaxsim import VelRepr from .utils import assert_allclose def test_tn_curve(jaxsim_model_single_pendulum: js.model.JaxSimModel): model = jaxsim_model_single_pendulum new_act_params = jaxsim.rbda.actuation.ActuationParams() with new_act_params.editable(validate=False) as new_act_params: new_act_params.torque_max = 10 new_act_params.omega_th = 1 new_act_params.omega_max = 2 with model.editable(validate=False) as model: model.actuation_params = new_act_params data = js.data.JaxSimModelData.build( model=model, velocity_representation=VelRepr.Inertial, ) new_joint_velocities = 1.5 * jnp.ones(model.dofs()) joint_torques_0 = 30 * jnp.ones(model.dofs()) data_0 = data.replace(model=model, joint_velocities=new_joint_velocities) τ_total = js.actuation_model.compute_resultant_torques( model, data_0, joint_force_references=joint_torques_0 ) assert_array_less(τ_total, joint_torques_0) new_joint_velocities = 2.5 * jnp.ones(model.dofs()) joint_torques_0 = 30 * jnp.ones(model.dofs()) data_0 = data.replace(model=model, joint_velocities=new_joint_velocities) τ_total = js.actuation_model.compute_resultant_torques( model, data_0, joint_force_references=joint_torques_0 ) assert_allclose(τ_total, 0.0) ================================================ FILE: tests/test_api_com.py ================================================ import jax import jaxsim.api as js from jaxsim import VelRepr from . 
import utils from .utils import assert_allclose def test_com_properties( jaxsim_models_types: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=velocity_representation ) kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data) # ===== # Tests # ===== p_com_idt = kin_dyn.com_position() p_com_js = js.com.com_position(model=model, data=data) assert_allclose(p_com_idt, p_com_js) J_Gh_idt = kin_dyn.centroidal_momentum_jacobian() J_Gh_js = js.com.centroidal_momentum_jacobian(model=model, data=data) assert_allclose(J_Gh_idt, J_Gh_js) h_com_idt = kin_dyn.centroidal_momentum() h_com_js = js.com.centroidal_momentum(model=model, data=data) assert_allclose(h_com_idt, h_com_js) M_com_locked_idt = kin_dyn.locked_centroidal_spatial_inertia() M_com_locked_js = js.com.locked_centroidal_spatial_inertia(model=model, data=data) assert_allclose(M_com_locked_idt, M_com_locked_js) J_avg_com_idt = kin_dyn.average_centroidal_velocity_jacobian() J_avg_com_js = js.com.average_centroidal_velocity_jacobian(model=model, data=data) assert_allclose(J_avg_com_idt, J_avg_com_js) v_avg_com_idt = kin_dyn.average_centroidal_velocity() v_avg_com_js = js.com.average_centroidal_velocity(model=model, data=data) assert_allclose(v_avg_com_idt, v_avg_com_js) # https://github.com/gbionics/jaxsim/pull/117#discussion_r1535486123 if data.velocity_representation is not VelRepr.Body: vl_com_idt = kin_dyn.com_velocity() vl_com_js = js.com.com_linear_velocity(model=model, data=data) assert_allclose(vl_com_idt, vl_com_js) # iDynTree provides the bias acceleration in G[W] frame regardless of the velocity # representation. JaxSim, instead, returns the bias acceleration in G[B] when the # active representation is VelRepr.Body. 
if data.velocity_representation is not VelRepr.Body: G_v̇_bias_WG_idt = kin_dyn.com_bias_acceleration() G_v̇_bias_WG_js = js.com.bias_acceleration(model=model, data=data) assert_allclose(G_v̇_bias_WG_idt, G_v̇_bias_WG_js) ================================================ FILE: tests/test_api_contact.py ================================================ import jax import jax.numpy as jnp import rod import jaxsim.api as js from jaxsim import VelRepr from .utils import assert_allclose def test_contact_kinematics( jaxsim_models_types: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=velocity_representation, ) # Get the indices of the enabled collidable points. indices_of_enabled_collidable_points = ( model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points ) parent_link_idx_of_enabled_collidable_points = jnp.array( model.kin_dyn_parameters.contact_parameters.body, dtype=int )[indices_of_enabled_collidable_points] # ===== # Tests # ===== # Compute the pose of the implicit contact frame associated to the collidable points # and the transforms of all links. W_H_C = js.contact.transforms(model=model, data=data) W_H_L = data._link_transforms # Check that the orientation of the implicit contact frame matches with the # orientation of the link to which the contact point is attached. for contact_idx, index_of_parent_link in enumerate( parent_link_idx_of_enabled_collidable_points ): assert_allclose( W_H_C[contact_idx, 0:3, 0:3], W_H_L[index_of_parent_link][0:3, 0:3] ) # Check that the origin of the implicit contact frame is located over the # collidable point. W_p_C = js.contact.collidable_point_positions(model=model, data=data) assert_allclose(W_p_C, W_H_C[:, 0:3, 3]) # Compute the velocity of the collidable point. 
# This quantity always matches with the linear component of the mixed 6D velocity # of the implicit frame associated to the collidable point. W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data) # Compute the velocity of the collidable point using the contact Jacobian. ν = data.generalized_velocity CW_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed) CW_vl_WC = jnp.einsum("c6g,g->c6", CW_J_WC, ν)[:, 0:3] # Compare the two velocities. assert_allclose(W_ṗ_C, CW_vl_WC) def test_collidable_point_jacobians( jaxsim_models_types: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=velocity_representation ) # ===== # Tests # ===== # Compute the velocity of the collidable points with a RBDA. # This function always returns the linear part of the mixed velocity of the # implicit frame C corresponding to the collidable point. W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data) # Compute the generalized velocity and the free-floating Jacobian of the frame C. ν = data.generalized_velocity CW_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed) # Compute the velocity of the collidable points using the Jacobians. v_WC_from_jax = jax.vmap(lambda J, ν: J @ ν, in_axes=(0, None))(CW_J_WC, ν) assert_allclose(W_ṗ_C, v_WC_from_jax[:, 0:3]) def test_contact_jacobian_derivative( jaxsim_models_floating_base: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jax.Array, ): model = jaxsim_models_floating_base _, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=velocity_representation, ) # Get the indices of the enabled collidable points. 
indices_of_enabled_collidable_points = ( model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points ) # Extract the parent link names and the poses of the contact points. parent_link_names = js.link.idxs_to_names( model=model, link_indices=jnp.array( model.kin_dyn_parameters.contact_parameters.body, dtype=int )[indices_of_enabled_collidable_points], ) L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[ indices_of_enabled_collidable_points ] # ===== # Tests # ===== # Load the model in ROD. rod_model = rod.Sdf.load(sdf=model.built_from).model # Add dummy frames on the contact points. for idx, link_name, L_p_C in zip( indices_of_enabled_collidable_points, parent_link_names, L_p_Ci, strict=True ): rod_model.add_frame( frame=rod.Frame( name=f"contact_point_{idx}", attached_to=link_name, pose=rod.Pose( relative_to=link_name, pose=jnp.zeros(shape=(6,)).at[0:3].set(L_p_C) ), ), ) # Rebuild the JaxSim model. model_with_frames = js.model.JaxSimModel.build_from_model_description( model_description=rod_model ) model_with_frames = js.model.reduce( model=model_with_frames, considered_joints=model.joint_names() ) # Rebuild the JaxSim data. data_with_frames = js.data.JaxSimModelData.build( model=model_with_frames, base_position=data.base_position, base_quaternion=data.base_orientation, joint_positions=data.joint_positions, base_linear_velocity=data.base_velocity[0:3], base_angular_velocity=data.base_velocity[3:6], joint_velocities=data.joint_velocities, velocity_representation=velocity_representation, ) # Extract the indexes of the frames attached to the contact points. frame_idxs = js.frame.names_to_idxs( model=model_with_frames, frame_names=( f"contact_point_{idx}" for idx in indices_of_enabled_collidable_points ), ) # Check that the number of frames is correct. assert len(frame_idxs) == len(parent_link_names) # Compute the contact Jacobian derivative. 
J̇_WC = js.contact.jacobian_derivative( model=model_with_frames, data=data_with_frames ) # Compute the contact Jacobian derivative using frames. J̇_WF = jax.vmap( js.frame.jacobian_derivative, in_axes=(None, None), )(model_with_frames, data, frame_index=frame_idxs) # Compare the two Jacobians. assert_allclose(J̇_WC, J̇_WF) ================================================ FILE: tests/test_api_data.py ================================================ import jax import jax.numpy as jnp import pytest from numpy.testing import assert_raises import jaxsim.api as js from jaxsim import VelRepr from jaxsim.utils import Mutability from . import utils from .utils import assert_allclose def test_data_valid( jaxsim_models_all: js.model.JaxSimModel, ): model = jaxsim_models_all data = js.data.JaxSimModelData.build(model=model) assert data.valid(model=model) def test_data_switch_velocity_representation( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=VelRepr.Inertial ) # ===== # Tests # ===== new_base_linear_velocity = jnp.array([1.0, -2.0, 3.0]) old_base_linear_velocity = data._base_linear_velocity # The following should not change the original `data` object since it raises. with pytest.raises(RuntimeError): with data.switch_velocity_representation( velocity_representation=VelRepr.Inertial ): with data.mutable_context(mutability=Mutability.MUTABLE): data._base_linear_velocity = new_base_linear_velocity raise RuntimeError("This is raised on purpose inside this context") assert_allclose(data._base_linear_velocity, old_base_linear_velocity) # The following instead should result to an updated `data` object. 
with ( data.switch_velocity_representation(velocity_representation=VelRepr.Inertial), data.mutable_context(mutability=Mutability.MUTABLE), ): data._base_linear_velocity = new_base_linear_velocity assert_allclose(data._base_linear_velocity, new_base_linear_velocity) def test_data_change_velocity_representation( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=VelRepr.Inertial ) # ===== # Tests # ===== kin_dyn_inertial = utils.build_kindyncomputations_from_jaxsim_model( model=model, data=data ) with data.switch_velocity_representation(VelRepr.Mixed): kin_dyn_mixed = utils.build_kindyncomputations_from_jaxsim_model( model=model, data=data ) with data.switch_velocity_representation(VelRepr.Body): kin_dyn_body = utils.build_kindyncomputations_from_jaxsim_model( model=model, data=data ) assert_allclose(data.base_velocity, kin_dyn_inertial.base_velocity()) if not model.floating_base(): return with data.switch_velocity_representation(VelRepr.Mixed): assert_allclose(data.base_velocity, kin_dyn_mixed.base_velocity()) assert_raises( AssertionError, assert_allclose, data.base_velocity[0:3], data._base_linear_velocity, ) assert_allclose(data.base_velocity[3:6], data._base_angular_velocity) with data.switch_velocity_representation(VelRepr.Body): assert_allclose(data.base_velocity, kin_dyn_body.base_velocity()) assert_raises( AssertionError, assert_allclose, data.base_velocity[0:3], data._base_linear_velocity, ) assert_raises( AssertionError, assert_allclose, data.base_velocity[3:6], data._base_angular_velocity, ) ================================================ FILE: tests/test_api_frame.py ================================================ import jax import jax.numpy as jnp import pytest from jax.errors import JaxRuntimeError from numpy.testing import assert_array_equal import jaxsim.api as js from jaxsim 
import VelRepr from jaxsim.math.quaternion import Quaternion from . import utils from .utils import assert_allclose def test_frame_index(jaxsim_models_types: js.model.JaxSimModel): model = jaxsim_models_types # ===== # Tests # ===== n_l = model.number_of_links() n_f = len(model.frame_names()) for idx, frame_name in enumerate(model.frame_names()): frame_index = n_l + idx assert js.frame.name_to_idx(model=model, frame_name=frame_name) == frame_index assert js.frame.idx_to_name(model=model, frame_index=frame_index) == frame_name assert ( js.frame.idx_of_parent_link(model=model, frame_index=frame_index) < model.number_of_links() ) # See discussion in https://github.com/gbionics/jaxsim/pull/280 assert_array_equal( js.frame.names_to_idxs(model=model, frame_names=model.frame_names()), jnp.arange(n_l, n_l + n_f), ) assert ( js.frame.idxs_to_names( model=model, frame_indices=tuple( js.frame.names_to_idxs( model=model, frame_names=model.frame_names() ).tolist() ), ) == model.frame_names() ) with pytest.raises(ValueError): _ = js.frame.name_to_idx(model=model, frame_name="non_existent_frame") with pytest.raises(JaxRuntimeError): _ = js.frame.idx_to_name(model=model, frame_index=-1) with pytest.raises(JaxRuntimeError): _ = js.frame.idx_to_name(model=model, frame_index=n_l - 1) with pytest.raises(JaxRuntimeError): _ = js.frame.idx_to_name(model=model, frame_index=n_l + n_f) with pytest.raises(JaxRuntimeError): _ = js.frame.idx_of_parent_link(model=model, frame_index=-1) with pytest.raises(JaxRuntimeError): _ = js.frame.idx_of_parent_link(model=model, frame_index=n_l - 1) with pytest.raises(JaxRuntimeError): _ = js.frame.idx_of_parent_link(model=model, frame_index=n_l + n_f) def test_frame_transforms( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=VelRepr.Inertial ) kin_dyn = 
utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data) # Get all names of frames in the iDynTree model. frame_names = [ frame.name for frame in model.description.frames if frame.name in kin_dyn.frame_names() ] # Skip some entry of models with many frames. frame_names = [ name for name in frame_names if "skin" not in name or "laser" not in name or "depth" not in name ] # Get indices of frames. frame_indices = tuple( frame.index for frame in model.description.frames if frame.index is not None and frame.name in frame_names ) # ===== # Tests # ===== assert len(frame_indices) == len(frame_names) for frame_name in frame_names: W_H_F_js = js.frame.transform( model=model, data=data, frame_index=js.frame.name_to_idx(model=model, frame_name=frame_name), ) W_H_F_idt = kin_dyn.frame_transform(frame_name=frame_name) assert_allclose( W_H_F_js, W_H_F_idt, atol=1e-6, err_msg=f"Mismatch in {frame_name}" ) def test_frame_jacobians( jaxsim_models_types: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=velocity_representation ) kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data) # Get all names of frames in the iDynTree model. frame_names = [ frame.name for frame in model.description.frames if frame.name in kin_dyn.frame_names() ] # Lower the number of frames for models with many frames. if model.name().lower() == "ergocub": assert any("sole" in name for name in frame_names) frame_names = [name for name in frame_names if "sole" in name] # Get indices of frames. 
frame_indices = tuple( frame.index for frame in model.description.frames if frame.index is not None and frame.name in frame_names ) # ===== # Tests # ===== assert len(frame_indices) == len(frame_names) for frame_name, frame_index in zip(frame_names, frame_indices, strict=True): J_WL_js = js.frame.jacobian(model=model, data=data, frame_index=frame_index) J_WL_idt = kin_dyn.jacobian_frame(frame_name=frame_name) assert_allclose(J_WL_js, J_WL_idt, err_msg=f"Mismatch in {frame_name}") for frame_name, frame_index in zip(frame_names, frame_indices, strict=True): v_WF_idt = kin_dyn.frame_velocity(frame_name=frame_name) v_WF_js = js.frame.velocity(model=model, data=data, frame_index=frame_index) assert_allclose(v_WF_js, v_WF_idt, err_msg=f"Mismatch in {frame_name}") def test_frame_jacobian_derivative( jaxsim_models_types: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=velocity_representation ) kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data) # Get all names of frames in the iDynTree model. frame_names = [ frame.name for frame in model.description.frames if frame.name in kin_dyn.frame_names() ] # Skip some entry of models with many frames. frame_names = [ name for name in frame_names if "skin" not in name or "laser" not in name or "depth" not in name ] frame_idxs = js.frame.names_to_idxs(model=model, frame_names=tuple(frame_names)) # =============== # Test against AD # =============== # Get the generalized velocity. I_ν = data.generalized_velocity # Compute J̇. O_J̇_WF_I = jax.vmap( lambda frame_index: js.frame.jacobian_derivative( model=model, data=data, frame_index=frame_index ) )(frame_idxs) assert O_J̇_WF_I.shape == (len(frame_names), 6, 6 + model.dofs()) # Compute the plain Jacobian. 
# This function will be used to compute the Jacobian derivative with AD. def J(q, frame_idxs) -> jax.Array: data_ad = js.data.JaxSimModelData.build( model=model, velocity_representation=data.velocity_representation, base_position=q[:3], base_quaternion=q[3:7], joint_positions=q[7:], ) O_J_ad_WF_I = jax.vmap( lambda model, data, frame_index: js.frame.jacobian( model=model, data=data, frame_index=frame_index ), in_axes=(None, None, 0), )(model, data_ad, frame_idxs) return O_J_ad_WF_I def compute_q(data: js.data.JaxSimModelData) -> jax.Array: q = jnp.hstack( [ data.base_position, data.base_orientation, data.joint_positions, ] ) return q def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: with data.switch_velocity_representation(VelRepr.Body): B_ω_WB = data.base_velocity[3:6] with data.switch_velocity_representation(VelRepr.Mixed): W_ṗ_B = data.base_velocity[0:3] W_Q̇_B = Quaternion.derivative( quaternion=data.base_orientation, omega=B_ω_WB, omega_in_body_fixed=True, K=0.0, ).squeeze() q̇ = jnp.hstack([W_ṗ_B, W_Q̇_B, data.joint_velocities]) return q̇ # Compute q and q̇. q = compute_q(data) q̇ = compute_q̇(data) # Compute dJ/dt with AD. dJ_dq = jax.jacfwd(J, argnums=0)(q, frame_idxs) O_J̇_ad_WF_I = jnp.einsum("ijkq,q->ijk", dJ_dq, q̇) assert_allclose(O_J̇_WF_I, O_J̇_ad_WF_I) # ===================== # Test against iDynTree # ===================== # Compute the product J̇ν. O_a_bias_WF = jax.vmap( lambda O_J̇_WF_I, I_ν: O_J̇_WF_I @ I_ν, in_axes=(0, None), )(O_J̇_WF_I, I_ν) # Compare the two computations. 
for index, name in enumerate(frame_names): J̇ν_idt = kin_dyn.frame_bias_acc(frame_name=name) J̇ν_js = O_a_bias_WF[index] assert_allclose(J̇ν_js, J̇ν_idt, err_msg=f"Mismatch in {name}") ================================================ FILE: tests/test_api_joint.py ================================================ import jax.numpy as jnp import pytest from jax.errors import JaxRuntimeError from numpy.testing import assert_array_equal import jaxsim.api as js def test_joint_index( jaxsim_models_types: js.model.JaxSimModel, ): model = jaxsim_models_types # ===== # Tests # ===== for idx, joint_name in enumerate(model.joint_names()): assert js.joint.name_to_idx(model=model, joint_name=joint_name) == idx assert js.joint.idx_to_name(model=model, joint_index=idx) == joint_name # See discussion in https://github.com/gbionics/jaxsim/pull/280 assert_array_equal( js.joint.names_to_idxs(model=model, joint_names=model.joint_names()), jnp.arange(model.number_of_joints()), ) assert ( js.joint.idxs_to_names( model=model, joint_indices=tuple( js.joint.names_to_idxs( model=model, joint_names=model.joint_names() ).tolist() ), ) == model.joint_names() ) with pytest.raises(ValueError): _ = js.joint.name_to_idx(model=model, joint_name="non_existent_joint") with pytest.raises(JaxRuntimeError): _ = js.joint.idx_to_name(model=model, joint_index=-1) with pytest.raises(IndexError): _ = js.joint.idx_to_name(model=model, joint_index=model.number_of_joints()) ================================================ FILE: tests/test_api_link.py ================================================ import jax import jax.numpy as jnp import pytest from jax.errors import JaxRuntimeError from numpy.testing import assert_array_equal import jaxsim.api as js import jaxsim.math from jaxsim import VelRepr from . 
import utils
from .utils import assert_allclose


def test_link_index(
    jaxsim_models_types: js.model.JaxSimModel,
):
    """
    Check link name/index conversions and the errors raised for invalid inputs.
    """

    model = jaxsim_models_types

    # =====
    # Tests
    # =====

    for idx, link_name in enumerate(model.link_names()):
        assert js.link.name_to_idx(model=model, link_name=link_name) == idx
        assert js.link.idx_to_name(model=model, link_index=idx) == link_name

    # See discussion in https://github.com/gbionics/jaxsim/pull/280
    assert_array_equal(
        js.link.names_to_idxs(model=model, link_names=model.link_names()),
        jnp.arange(model.number_of_links()),
    )

    assert (
        js.link.idxs_to_names(
            model=model,
            link_indices=tuple(
                js.link.names_to_idxs(
                    model=model, link_names=model.link_names()
                ).tolist()
            ),
        )
        == model.link_names()
    )

    with pytest.raises(ValueError):
        _ = js.link.name_to_idx(model=model, link_name="non_existent_link")

    with pytest.raises(JaxRuntimeError):
        _ = js.link.idx_to_name(model=model, link_index=-1)

    with pytest.raises(IndexError):
        _ = js.link.idx_to_name(model=model, link_index=model.number_of_links())


def test_link_inertial_properties(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """
    Compare link masses and spatial inertias with the reference `kin_dyn`
    computations. The base link is skipped.
    """

    model = jaxsim_models_types

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model,
        key=subkey,
        velocity_representation=VelRepr.Inertial,
    )

    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)

    # =====
    # Tests
    # =====

    for link_name, link_idx in zip(
        model.link_names(),
        jnp.arange(model.number_of_links()),
        strict=True,
    ):
        # NOTE(review): the base link is excluded; the reason is not visible
        # here. Confirm.
        if link_name == model.base_link():
            continue

        assert_allclose(
            js.link.mass(model=model, link_index=link_idx),
            kin_dyn.link_mass(link_name=link_name),
            err_msg=f"Mismatch in {link_name}",
        )

        assert_allclose(
            js.link.spatial_inertia(model=model, link_index=link_idx),
            kin_dyn.link_spatial_inertia(link_name=link_name),
            err_msg=f"Mismatch in {link_name}",
        )


def test_link_transforms(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """
    Check that the link transforms stored in the data, those returned by
    `js.link.transform`, and the reference `kin_dyn` transforms all agree.
    """

    model = jaxsim_models_types

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model,
        key=subkey,
        velocity_representation=VelRepr.Inertial,
    )

    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)

    # =====
    # Tests
    # =====

    W_H_LL_model = data._link_transforms

    W_H_LL_links = jax.vmap(
        lambda idx: js.link.transform(model=model, data=data, link_index=idx)
    )(jnp.arange(model.number_of_links()))

    assert_allclose(W_H_LL_model, W_H_LL_links)

    for W_H_L, link_name in zip(W_H_LL_links, model.link_names(), strict=True):
        assert_allclose(
            W_H_L,
            kin_dyn.frame_transform(frame_name=link_name),
            err_msg=f"Mismatch in {link_name}",
        )


def test_link_jacobians(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Compare link Jacobians and link velocities with the reference `kin_dyn`
    computations. Velocities are also checked when requested in every other
    output velocity representation.
    """

    model = jaxsim_models_types

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model,
        key=subkey,
        velocity_representation=velocity_representation,
    )

    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)

    # =====
    # Tests
    # =====

    J_WL_links = jax.vmap(
        lambda idx: js.link.jacobian(model=model, data=data, link_index=idx)
    )(jnp.arange(model.number_of_links()))

    for J_WL, link_name in zip(J_WL_links, model.link_names(), strict=True):
        assert_allclose(
            J_WL,
            kin_dyn.jacobian_frame(frame_name=link_name),
            err_msg=f"Mismatch in {link_name}",
        )

    # The following is true only in inertial-fixed representation.
    # NOTE(review): the assertion below runs for every velocity representation,
    # so the comment above looks stale. Confirm.
    J_WL_model = js.model.generalized_free_floating_jacobian(model=model, data=data)
    assert_allclose(J_WL_model, J_WL_links)

    for link_name, link_idx in zip(
        model.link_names(),
        jnp.arange(model.number_of_links()),
        strict=True,
    ):
        v_WL_idt = kin_dyn.frame_velocity(frame_name=link_name)
        v_WL_js = js.link.velocity(model=model, data=data, link_index=link_idx)
        assert_allclose(v_WL_js, v_WL_idt, err_msg=f"Mismatch in {link_name}")

    # Test conversion to a different output velocity representation.
    for other_repr in {VelRepr.Inertial, VelRepr.Body, VelRepr.Mixed}.difference(
        {data.velocity_representation}
    ):

        # Build the reference in `other_repr`. The JaxSim velocities below are
        # requested through `output_vel_repr` instead.
        with data.switch_velocity_representation(other_repr):
            kin_dyn_other_repr = utils.build_kindyncomputations_from_jaxsim_model(
                model=model, data=data
            )

        for link_name, link_idx in zip(
            model.link_names(),
            jnp.arange(model.number_of_links()),
            strict=True,
        ):
            v_WL_idt = kin_dyn_other_repr.frame_velocity(frame_name=link_name)
            v_WL_js = js.link.velocity(
                model=model, data=data, link_index=link_idx, output_vel_repr=other_repr
            )
            assert_allclose(v_WL_js, v_WL_idt, err_msg=f"Mismatch in {link_name}")


def test_link_bias_acceleration(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Compare link bias accelerations with the reference `kin_dyn` computations.
    Also check the inertial<->body conversion of the bias accelerations through
    the adjoint matrices of the link transforms.
    """

    model = jaxsim_models_types

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model,
        key=subkey,
        velocity_representation=velocity_representation,
    )

    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)

    # =====
    # Tests
    # =====

    for name, index in zip(
        model.link_names(),
        jnp.arange(model.number_of_links()),
        strict=True,
    ):
        Jν_idt = kin_dyn.frame_bias_acc(frame_name=name)
        Jν_js = js.link.bias_acceleration(model=model, data=data, link_index=index)
        assert_allclose(Jν_js, Jν_idt, err_msg=f"Mismatch in {name}")

    # Test that the conversion of the link bias acceleration works as expected.
    match data.velocity_representation:

        # We exclude the mixed representation because converting the acceleration is
        # more complex than using the plain 6D transform matrix.
        case VelRepr.Mixed:
            pass

        # Inertial-fixed to body-fixed conversion.
        case VelRepr.Inertial:

            W_H_L = data._link_transforms

            W_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)

            with data.switch_velocity_representation(VelRepr.Body):

                W_X_L = jax.vmap(
                    lambda W_H_L: jaxsim.math.Adjoint.from_transform(transform=W_H_L)
                )(W_H_L)

                L_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)

                W_a_bias_WL_converted = jax.vmap(
                    lambda W_X_L, L_a_bias_WL: W_X_L @ L_a_bias_WL
                )(W_X_L, L_a_bias_WL)

            assert_allclose(W_a_bias_WL, W_a_bias_WL_converted)

        # Body-fixed to inertial-fixed conversion.
        case VelRepr.Body:

            W_H_L = data._link_transforms

            L_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)

            with data.switch_velocity_representation(VelRepr.Inertial):

                L_X_W = jax.vmap(
                    lambda W_H_L: jaxsim.math.Adjoint.from_transform(
                        transform=W_H_L, inverse=True
                    )
                )(W_H_L)

                W_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)

                L_a_bias_WL_converted = jax.vmap(
                    lambda L_X_W, W_a_bias_WL: L_X_W @ W_a_bias_WL
                )(L_X_W, W_a_bias_WL)

            assert_allclose(L_a_bias_WL, L_a_bias_WL_converted)


def test_link_jacobian_derivative(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Check link Jacobian derivatives against `js.model.link_bias_accelerations`
    (through the product J̇ν) and against forward-mode AD of the stacked link
    Jacobians.
    """

    model = jaxsim_models_types

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model,
        key=subkey,
        velocity_representation=velocity_representation,
    )

    # =====
    # Tests
    # =====

    # Get the generalized velocity.
    I_ν = data.generalized_velocity

    # Compute J̇.
    O_J̇_WL_I = jax.vmap(
        lambda link_index: js.link.jacobian_derivative(
            model=model, data=data, link_index=link_index
        )
    )(jnp.arange(model.number_of_links()))

    # Compute the product J̇ν.
    O_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)

    # Compare the two computations.
    assert_allclose(jnp.einsum("l6g,g->l6", O_J̇_WL_I, I_ν), O_a_bias_WL)

    # Compute the plain Jacobian.
    # This function will be used to compute the Jacobian derivative with AD.
    # Given q, computing J̇ by AD-ing this function should work out-of-the-box with
    # all velocity representations, that are handled internally when computing J.
    def J(q) -> jax.Array:

        data_ad = js.data.JaxSimModelData.build(
            model=model,
            velocity_representation=data.velocity_representation,
            base_position=q[:3],
            base_quaternion=q[3:7],
            joint_positions=q[7:],
        )

        O_J_WL_I = js.model.generalized_free_floating_jacobian(
            model=model, data=data_ad
        )

        return O_J_WL_I

    # q = [base position (3), base quaternion (4), joint positions].
    def compute_q(data: js.data.JaxSimModelData) -> jax.Array:

        q = jnp.hstack(
            [data.base_position, data.base_orientation, data.joint_positions]
        )

        return q

    # q̇ follows the layout of q: base position rate, quaternion rate, joint velocities.
    def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:

        with data.switch_velocity_representation(VelRepr.Body):
            B_ω_WB = data.base_velocity[3:6]

        with data.switch_velocity_representation(VelRepr.Mixed):
            W_ṗ_B = data.base_velocity[0:3]

        W_Q̇_B = jaxsim.math.Quaternion.derivative(
            quaternion=data.base_orientation,
            omega=B_ω_WB,
            omega_in_body_fixed=True,
            K=0.0,
        ).squeeze()

        q̇ = jnp.hstack([W_ṗ_B, W_Q̇_B, data.joint_velocities])

        return q̇

    # Compute q and q̇.
    q = compute_q(data)
    q̇ = compute_q̇(data)

    # Compute dJ/dt with AD.
    dJ_dq = jax.jacfwd(J, argnums=0)(q)
    O_J̇_ad_WL_I = jnp.einsum("ijkq,q->ijk", dJ_dq, q̇)

    assert_allclose(O_J̇_WL_I, O_J̇_ad_WL_I)
    assert_allclose(
        jnp.einsum("l6g,g->l6", O_J̇_ad_WL_I, I_ν),
        jnp.einsum("l6g,g->l6", O_J̇_WL_I, I_ν),
    )

================================================
FILE: tests/test_api_model.py
================================================
import pathlib

import jax
import jax.numpy as jnp
import numpy as np
import rod

import jaxsim.api as js
import jaxsim.math
from jaxsim import VelRepr

from . 
import utils from .utils import assert_allclose def test_model_creation_and_reduction( jaxsim_model_ergocub: js.model.JaxSimModel, prng_key: jax.Array, ): model_full = jaxsim_model_ergocub _, subkey = jax.random.split(prng_key, num=2) data_full = js.data.random_model_data( model=model_full, key=subkey, velocity_representation=VelRepr.Inertial, base_pos_bounds=((0, 0, 0.8), (0, 0, 0.8)), ) # ===== # Tests # ===== # Check that the data of the full model is valid. assert data_full.valid(model=model_full) # Build the ROD model from the original description. assert isinstance(model_full.built_from, str | pathlib.Path) rod_sdf = rod.Sdf.load(sdf=model_full.built_from) assert len(rod_sdf.models()) == 1 # Get all non-fixed joint names from the description. joint_names_in_description = [ j.name for j in rod_sdf.models()[0].joints() if j.type != "fixed" ] # Check that all non-fixed joints are in the full model. assert set(joint_names_in_description) == set(model_full.joint_names()) # ================ # Reduce the model # ================ # Get the names of the joints to keep in the reduced model. reduced_joints = tuple( j for j in model_full.joint_names() if "camera" not in j and "neck" not in j and "wrist" not in j and "thumb" not in j and "index" not in j and "middle" not in j and "ring" not in j and "pinkie" not in j # and "elbow" not in j and "shoulder" not in j and "torso" not in j and "r_knee" not in j ) # Reduce the model. # Note: here we also specify a non-zero position of the removed joints. # The process should take into account the corresponding joint transforms # when the link-joint-link chains are lumped together. model_reduced = js.model.reduce( model=model_full, considered_joints=reduced_joints, locked_joint_positions=dict( zip( model_full.joint_names(), data_full.joint_positions.tolist(), strict=True, ) ), ) # Check DoFs. assert model_full.dofs() != model_reduced.dofs() # Check that all non-fixed joints are in the reduced model. 
assert set(reduced_joints) == set(model_reduced.joint_names()) # Check that the reduced model maintains the same terrain of the full model. assert model_full.terrain == model_reduced.terrain # Check that the reduced model maintains the same contact model of the full model. assert model_full.contact_model == model_reduced.contact_model # Check that the reduced model maintains the same integration step of the full model. assert model_full.time_step == model_reduced.time_step joint_idxs = js.joint.names_to_idxs( model=model_full, joint_names=model_reduced.joint_names() ) # Build the data of the reduced model. data_reduced = js.data.JaxSimModelData.build( model=model_reduced, base_position=data_full.base_position, base_quaternion=data_full.base_orientation, joint_positions=data_full.joint_positions[joint_idxs], base_linear_velocity=data_full.base_velocity[0:3], base_angular_velocity=data_full.base_velocity[3:6], joint_velocities=data_full.joint_velocities[joint_idxs], velocity_representation=data_full.velocity_representation, ) # Check that the reduced model data is valid. assert not data_reduced.valid(model=model_full) assert data_reduced.valid(model=model_reduced) # Check that the total mass is preserved. assert_allclose( js.model.total_mass(model=model_full), js.model.total_mass(model=model_reduced) ) # Check that the CoM position is preserved. assert_allclose( js.com.com_position(model=model_full, data=data_full), js.com.com_position(model=model_reduced, data=data_reduced), atol=1e-6, ) # Check that joint serialization works. assert_allclose(data_full.joint_positions[joint_idxs], data_reduced.joint_positions) assert_allclose( data_full.joint_velocities[joint_idxs], data_reduced.joint_velocities ) # Check that link transforms are preserved. 
for link_name in model_reduced.link_names(): W_H_L_full = js.link.transform( model=model_full, data=data_full, link_index=js.link.name_to_idx(model=model_full, link_name=link_name), ) W_H_L_reduced = js.link.transform( model=model_reduced, data=data_reduced, link_index=js.link.name_to_idx(model=model_reduced, link_name=link_name), ) assert_allclose(W_H_L_full, W_H_L_reduced) # Check that collidable point positions are preserved. assert_allclose( js.contact.collidable_point_positions(model=model_full, data=data_full), js.contact.collidable_point_positions(model=model_reduced, data=data_reduced), ) # ===================== # Test against iDynTree # ===================== kin_dyn_full = utils.build_kindyncomputations_from_jaxsim_model( model=model_full, data=data_full ) kin_dyn_reduced = utils.build_kindyncomputations_from_jaxsim_model( model=model_reduced, data=data_reduced ) # Check that the total mass is preserved. assert_allclose(kin_dyn_full.total_mass(), kin_dyn_reduced.total_mass()) # Check that the CoM position match. assert_allclose(kin_dyn_full.com_position(), kin_dyn_reduced.com_position()) assert_allclose( kin_dyn_full.com_position(), js.com.com_position(model=model_reduced, data=data_reduced), ) # Check that link transforms match. for link_name in model_reduced.link_names(): assert_allclose( kin_dyn_reduced.frame_transform(frame_name=link_name), kin_dyn_full.frame_transform(frame_name=link_name), err_msg=f"Mismatch in link {link_name}", ) assert_allclose( kin_dyn_reduced.frame_transform(frame_name=link_name), js.link.transform( model=model_reduced, data=data_reduced, link_index=js.link.name_to_idx( model=model_reduced, link_name=link_name ), ), err_msg=f"Mismatch in link {link_name}", ) # Check that frame transforms match. for frame_name in model_reduced.frame_names(): if frame_name not in kin_dyn_reduced.frame_names(): continue # Skip some entry of models with many frames. 
if "skin" in frame_name or "laser" in frame_name or "depth" in frame_name: continue assert_allclose( kin_dyn_reduced.frame_transform(frame_name=frame_name), kin_dyn_full.frame_transform(frame_name=frame_name), err_msg=f"Mismatch in frame {frame_name}", ) assert_allclose( kin_dyn_reduced.frame_transform(frame_name=frame_name), js.frame.transform( model=model_reduced, data=data_reduced, frame_index=js.frame.name_to_idx( model=model_reduced, frame_name=frame_name ), ), err_msg=f"Mismatch in frame {frame_name}", ) def test_model_properties( jaxsim_models_types: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=velocity_representation ) kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data) # ===== # Tests # ===== m_idt = kin_dyn.total_mass() m_js = js.model.total_mass(model=model) assert_allclose(m_idt, m_js) J_Bh_idt = kin_dyn.total_momentum_jacobian() J_Bh_js = js.model.total_momentum_jacobian(model=model, data=data) assert_allclose(J_Bh_idt, J_Bh_js) h_tot_idt = kin_dyn.total_momentum() h_tot_js = js.model.total_momentum(model=model, data=data) assert_allclose(h_tot_idt, h_tot_js) M_locked_idt = kin_dyn.locked_spatial_inertia() M_locked_js = js.model.locked_spatial_inertia(model=model, data=data) assert_allclose(M_locked_idt, M_locked_js) J_avg_idt = kin_dyn.average_velocity_jacobian() J_avg_js = js.model.average_velocity_jacobian(model=model, data=data) assert_allclose(J_avg_idt, J_avg_js) v_avg_idt = kin_dyn.average_velocity() v_avg_js = js.model.average_velocity(model=model, data=data) assert_allclose(v_avg_idt, v_avg_js) def test_model_rbda( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, velocity_representation: VelRepr, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( 
model=model, key=subkey, velocity_representation=velocity_representation ) kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data) # ===== # Tests # ===== # Support both fixed-base and floating-base models by slicing the first six rows. sl = np.s_[0:] if model.floating_base() else np.s_[6:] # Mass matrix M_idt = kin_dyn.mass_matrix() M_js = js.model.free_floating_mass_matrix(model=model, data=data) assert_allclose(M_idt[sl, sl], M_js[sl, sl]) # Gravity forces g_idt = kin_dyn.gravity_forces() g_js = js.model.free_floating_gravity_forces(model=model, data=data) assert_allclose(g_idt[sl], g_js[sl]) # Bias forces h_idt = kin_dyn.bias_forces() h_js = js.model.free_floating_bias_forces(model=model, data=data) assert_allclose(h_idt[sl], h_js[sl]) # Forward kinematics HH_js = data._link_transforms HH_idt = jnp.stack( [kin_dyn.frame_transform(frame_name=name) for name in model.link_names()] ) assert_allclose(HH_idt, HH_js) # Bias accelerations Jν_js = js.model.link_bias_accelerations(model=model, data=data) Jν_idt = jnp.stack( [kin_dyn.frame_bias_acc(frame_name=name) for name in model.link_names()] ) assert_allclose(Jν_idt, Jν_js) # Mass matrix inverse via RBDA M_inv_js = js.model.free_floating_mass_matrix_inverse(model=model, data=data) M_inv_idt = jnp.linalg.inv(M_idt) assert_allclose(M_inv_idt[sl, sl], M_inv_js[sl, sl]) def test_model_jacobian( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, ): model = jaxsim_models_types key, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=VelRepr.Inertial ) # ===== # Tests # ===== # Create random references (joint torques and link forces) _, subkey1, subkey2 = jax.random.split(key, num=3) references = js.references.JaxSimModelReferences.build( model=model, joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)), link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)), 
data=data, velocity_representation=data.velocity_representation, ) # Remove the force applied to the base link if the model is fixed-base if not model.floating_base(): references = references.apply_link_forces( forces=jnp.atleast_2d(jnp.zeros(6)), model=model, data=data, link_names=(model.base_link(),), additive=False, ) # Get the J.T @ f product in inertial-fixed input/output representation. # We use doubly right-trivialized jacobian with inertial-fixed 6D forces. with ( references.switch_velocity_representation(VelRepr.Inertial), data.switch_velocity_representation(VelRepr.Inertial), ): f = references.link_forces(model=model, data=data) assert_allclose(f, references._link_forces) J = js.model.generalized_free_floating_jacobian(model=model, data=data) JTf_inertial = jnp.einsum("l6g,l6->g", J, f) for vel_repr in (VelRepr.Body, VelRepr.Mixed): with references.switch_velocity_representation(vel_repr): # Get the jacobian having an inertial-fixed input representation (so that # it computes the same quantity computed above) and an output representation # compatible with the frame in which the external forces are expressed. with data.switch_velocity_representation(VelRepr.Inertial): J = js.model.generalized_free_floating_jacobian( model=model, data=data, output_vel_repr=vel_repr ) # Get the forces in the tested representation and compute the product # O_J_WL_W.T @ O_f, producing a generalized acceleration in W. # The resulting acceleration can be tested again the one computed before. 
with data.switch_velocity_representation(vel_repr): f = references.link_forces(model=model, data=data) JTf_other = jnp.einsum("l6g,l6->g", J, f) assert_allclose(JTf_inertial, JTf_other, err_msg=vel_repr.name) def test_coriolis_matrix( jaxsim_models_types: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=velocity_representation ) # ===== # Tests # ===== I_ν = data.generalized_velocity C = js.model.free_floating_coriolis_matrix(model=model, data=data) h = js.model.free_floating_bias_forces(model=model, data=data) g = js.model.free_floating_gravity_forces(model=model, data=data) Cν = h - g assert_allclose(C @ I_ν, Cν) # Compute the free-floating mass matrix. # This function will be used to compute the Ṁ with AD. # Given q, computing Ṁ by AD-ing this function should work out-of-the-box with # all velocity representations, that are handled internally when computing M. def M(q) -> jax.Array: data_ad = js.data.JaxSimModelData.build( model=model, velocity_representation=data.velocity_representation, base_position=q[:3], base_quaternion=q[3:7], joint_positions=q[7:], ) M = js.model.free_floating_mass_matrix(model=model, data=data_ad) return M def compute_q(data: js.data.JaxSimModelData) -> jax.Array: q = jnp.hstack( [data.base_position, data.base_orientation, data.joint_positions] ) return q def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: with data.switch_velocity_representation(VelRepr.Body): B_ω_WB = data.base_velocity[3:6] with data.switch_velocity_representation(VelRepr.Mixed): W_ṗ_B = data.base_velocity[0:3] W_Q̇_B = jaxsim.math.Quaternion.derivative( quaternion=data.base_orientation, omega=B_ω_WB, omega_in_body_fixed=True, K=0.0, ).squeeze() q̇ = jnp.hstack([W_ṗ_B, W_Q̇_B, data.joint_velocities]) return q̇ # Compute q and q̇. 
q = compute_q(data) q̇ = compute_q̇(data) # Compute Ṁ with AD. dM_dq = jax.jacfwd(M, argnums=0)(q) Ṁ = jnp.einsum("ijq,q->ij", dM_dq, q̇) # We need to zero the blocks projecting joint variables to the base configuration # for fixed-base models. if not model.floating_base(): Ṁ = Ṁ.at[0:6, 6:].set(0) Ṁ = Ṁ.at[6:, 0:6].set(0) # Ensure that (Ṁ - 2C) is skew symmetric. assert_allclose(Ṁ - C - C.T, 0.0) def test_model_fd_id_consistency( jaxsim_models_types: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jax.Array, ): model = jaxsim_models_types key, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=velocity_representation ) # ===== # Tests # ===== # Create random references (joint torques and link forces). _, subkey1, subkey2 = jax.random.split(key, num=3) references = js.references.JaxSimModelReferences.build( model=model, joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)), link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)), data=data, velocity_representation=data.velocity_representation, ) # Remove the force applied to the base link if the model is fixed-base. if not model.floating_base(): references = references.apply_link_forces( forces=jnp.atleast_2d(jnp.zeros(6)), model=model, data=data, link_names=(model.base_link(),), additive=False, ) # Compute forward dynamics with ABA. v̇_WB_aba, s̈_aba = js.model.forward_dynamics_aba( model=model, data=data, joint_forces=references.joint_force_references(), link_forces=references.link_forces(model=model, data=data), ) # Compute forward dynamics with CRB. 
v̇_WB_crb, s̈_crb = js.model.forward_dynamics_crb( model=model, data=data, joint_forces=references.joint_force_references(), link_forces=references.link_forces(model=model, data=data), ) assert_allclose(s̈_aba, s̈_crb) assert_allclose(v̇_WB_aba, v̇_WB_crb) # Compute inverse dynamics with the quantities computed by forward dynamics fB_id, τ_id = js.model.inverse_dynamics( model=model, data=data, joint_accelerations=s̈_aba, base_acceleration=v̇_WB_aba, link_forces=references.link_forces(model=model, data=data), ) # Check consistency between FD and ID assert_allclose(τ_id, references.joint_force_references(model=model)) assert_allclose(fB_id, 0.0) if model.floating_base(): # If we remove the base 6D force from the inputs, we should find it as output. fB_id, τ_id = js.model.inverse_dynamics( model=model, data=data, joint_accelerations=s̈_aba, base_acceleration=v̇_WB_aba, link_forces=references.link_forces(model=model, data=data) .at[0] .set(jnp.zeros(6)), ) assert_allclose(τ_id, references.joint_force_references(model=model)) assert_allclose(fB_id, references.link_forces(model=model, data=data)[0]) def test_aba_vs_aba_parallel( jaxsim_models_all: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jax.Array, ): """ Verify that the level-parallel ABA produces identical results to the sequential ABA, both at the low-level RBDA and at the high-level model API. """ model = jaxsim_models_all key, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=velocity_representation ) # Create random references. 
_, subkey1, subkey2 = jax.random.split(key, num=3) references = js.references.JaxSimModelReferences.build( model=model, joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)), link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)), data=data, velocity_representation=data.velocity_representation, ) if not model.floating_base(): references = references.apply_link_forces( forces=jnp.atleast_2d(jnp.zeros(6)), model=model, data=data, link_names=(model.base_link(),), additive=False, ) joint_forces = references.joint_force_references() link_forces = references.link_forces(model=model, data=data) v̇_WB_seq, s̈_seq = js.model.forward_dynamics_aba( model=model, data=data, joint_forces=joint_forces, link_forces=link_forces, parallel=False, ) v̇_WB_par, s̈_par = js.model.forward_dynamics_aba( model=model, data=data, joint_forces=joint_forces, link_forces=link_forces, parallel=True, ) assert_allclose(v̇_WB_seq, v̇_WB_par, atol=1e-9) assert_allclose(s̈_seq, s̈_par, atol=1e-9) def test_fk_vs_fk_parallel( jaxsim_models_all: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jax.Array, ): """ Verify that the level-parallel FK produces identical results to the sequential FK. 
""" model = jaxsim_models_all _, subkey = jax.random.split(prng_key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=velocity_representation ) W_H_seq = js.model.forward_kinematics(model=model, data=data, parallel=False) W_H_par = js.model.forward_kinematics(model=model, data=data, parallel=True) assert_allclose(W_H_seq, W_H_par, atol=1e-9) ================================================ FILE: tests/test_api_model_hw_parametrization.py ================================================ import pathlib import xml.etree.ElementTree as ET import jax import jax.numpy as jnp import numpy as np import pytest import rod import jaxsim.api as js from jaxsim.api.kin_dyn_parameters import ( HwLinkMetadata, LinkParametrizableShape, ScalingFactors, ) from jaxsim.rbda.contacts import SoftContactsParams from .utils import assert_allclose def test_update_hw_link_parameters(jaxsim_model_garpez: js.model.JaxSimModel): """ Test that the hardware parameters of the model are updated correctly. 
""" model = jaxsim_model_garpez # Store initial hardware parameters initial_metadata = model.kin_dyn_parameters.hw_link_metadata # Create the scaling factors scaling_parameters = ScalingFactors( dims=jnp.array( [ [2.0, 1.5, 1.0], # Scale x, y, z for link1 [1.2, 1.0, 1.0], # Scale r for link2 [1.5, 0.8, 1.0], # Scale r, l for link3 [1.5, 1.0, 0.8], # Scale x, y, z for link4 ] ), density=jnp.ones(4), ) # Update the model using the scaling factors updated_model = js.model.update_hw_parameters(model, scaling_parameters) # Compare updated hardware parameters for link_idx, link_name in enumerate(model.link_names()): updated_metadata = jax.tree.map( lambda x, link_idx=link_idx: x[link_idx], updated_model.kin_dyn_parameters.hw_link_metadata, ) initial_metadata_link = jax.tree.map( lambda x, link_idx=link_idx: x[link_idx], initial_metadata ) # TODO: Compute the 3D scaling vector # scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector( # initial_metadata_link.shape, scaling_parameters.dims[link_idx] # ) expected_link_dimensions = ( initial_metadata_link.geometry * scaling_parameters.dims[link_idx] ) # Compare shape dimensions assert_allclose( updated_metadata.geometry, expected_link_dimensions, atol=1e-6, err_msg=f"Mismatch in dimensions for link {link_name}", ) @pytest.mark.parametrize( "jaxsim_model_garpez_scaled", [ { "link1_scale": 4.0, "link2_scale": 3.0, "link3_scale": 2.0, "link4_scale": 1.5, } ], indirect=True, ) def test_model_scaling_against_rod( jaxsim_model_garpez: js.model.JaxSimModel, jaxsim_model_garpez_scaled: js.model.JaxSimModel, ): """ Test that scaling the HW parameters of JaxSim model matches the kin/dyn quantities of a JaxSim model obtained from a pre-scaled rod model. 
""" # Define scaling parameters scaling_parameters = ScalingFactors( dims=jnp.array( [ [4.0, 1.0, 1.0], # Scale only x-dimension for link1 [3.0, 1.0, 1.0], # Scale only r-dimension for link2 [1.0, 2.0, 1.0], # Scale l dimension for link3 [1.5, 1.0, 1.0], # Scale only x-dimension for link4 ] ), density=jnp.ones(4), ) # Apply scaling to the original JaxSim model updated_model = js.model.update_hw_parameters( jaxsim_model_garpez, scaling_parameters ) # Compare hardware parameters of the scaled JaxSim model with the pre-scaled JaxSim model scaled_metadata = updated_model.kin_dyn_parameters.hw_link_metadata pre_scaled_metadata = jaxsim_model_garpez_scaled.kin_dyn_parameters.hw_link_metadata # Compare shape dimensions assert_allclose(scaled_metadata.geometry, pre_scaled_metadata.geometry, atol=1e-6) # Compare mass scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia(scaled_metadata) pre_scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia(pre_scaled_metadata) assert_allclose(scaled_mass, pre_scaled_mass, atol=1e-6) # Compare inertia tensors _, scaled_inertia = HwLinkMetadata.compute_mass_and_inertia(scaled_metadata) _, pre_scaled_inertia = HwLinkMetadata.compute_mass_and_inertia(pre_scaled_metadata) assert_allclose(scaled_inertia, pre_scaled_inertia, atol=1e-6) # Compare transformations assert_allclose(scaled_metadata.L_H_G, pre_scaled_metadata.L_H_G, atol=1e-6) assert_allclose(scaled_metadata.L_H_vis, pre_scaled_metadata.L_H_vis, atol=1e-6) # Compare collidable points positions assert_allclose( jaxsim_model_garpez_scaled.kin_dyn_parameters.contact_parameters.point, updated_model.kin_dyn_parameters.contact_parameters.point, atol=1e-6, ) def test_update_hw_parameters_vmap( jaxsim_model_garpez: js.model.JaxSimModel, ): """ Test that the hardware parameters of the model are updated correctly using vmap to create a set of n updated models. 
""" model_nominal = jaxsim_model_garpez dofs = model_nominal.dofs() # Define a set of scaling factors for n models n = 10 # Number of updated models to create scaling_factors = [ ScalingFactors( dims=(scale * jnp.ones((model_nominal.number_of_links(), 3))), density=(scale * jnp.ones(model_nominal.number_of_links())), ) for scale in jnp.linspace(2.0, 2.0 + n - 1, n) ] # Convert the list of ScalingFactors to a JAX array of pytrees scaling_factors = jax.tree.map(lambda *l: jnp.stack(l), *scaling_factors) # Generate a batch of updated models using vmap updated_models = jax.vmap(js.model.update_hw_parameters, in_axes=(None, 0))( model_nominal, scaling_factors, ) def validate_model(updated_model): assert updated_model is not None # Compute forward kinematics for the "link3" link H_link3 = js.link.transform( model=updated_model, data=js.data.JaxSimModelData.build(model=updated_model), link_index=js.link.name_to_idx(model=updated_model, link_name="link3"), ) # Compute the mass matrix M = js.model.free_floating_mass_matrix( model=updated_model, data=js.data.JaxSimModelData.build(model=updated_model), ) assert H_link3 is not None assert H_link3.shape == (4, 4) assert M is not None assert isinstance(M, jnp.ndarray) assert M.shape == (6 + dofs, 6 + dofs) # Use vmap to validate all updated models jax.vmap(validate_model)(updated_models) @pytest.mark.parametrize( "jaxsim_model_garpez_scaled", [ { "link1_scale": 4.0, "link2_scale": 3.0, "link3_scale": 2.0, "link4_scale": 1.5, } ], indirect=True, ) def test_export_updated_model( jaxsim_model_garpez: js.model.JaxSimModel, jaxsim_model_garpez_scaled: js.model.JaxSimModel, ): """ Test the export of an updated model using JaxSimModel.export_updated_model. 
""" model: js.model.JaxSimModel = jaxsim_model_garpez # Define scaling parameters scaling_parameters = ScalingFactors( dims=jnp.array( [ [4.0, 1.0, 1.0], # Scale x-dimension for link1 [3.0, 1.0, 1.0], # Scale r-dimension for link2 [1.0, 2.0, 1.0], # Scale l-dimension for link3 [1.5, 1.0, 1.0], # Scale x-dimension for link4 ] ), density=jnp.ones(4), ) identity_scaling = ScalingFactors( dims=jnp.ones((model.number_of_links(), 3)), density=jnp.ones(model.number_of_links()), ) def get_link_by_name(model, name): try: return next(link for link in model.links() if link.name == name) except StopIteration as err: raise ValueError( f"Link '{name}' not found. Available links: {[l.name for l in model.links()]}" ) from err def compare_geometries(exported_link, ref_link, label=""): exported_geom = exported_link.visual.geometry.geometry() ref_geom = ref_link.visual.geometry.geometry() attrs = [attr for attr in vars(exported_geom) if hasattr(ref_geom, attr)] exported_vals = jnp.array([getattr(exported_geom, attr) for attr in attrs]) ref_vals = jnp.array([getattr(ref_geom, attr) for attr in attrs]) assert_allclose( exported_vals, ref_vals, err_msg=f"Geometry mismatch in {label} model.", atol=1e-6, ) def compare_mass_and_inertia(exported_link, ref_link, label=""): assert_allclose( exported_link.inertial.mass, ref_link.inertial.mass, atol=1e-4, err_msg=f"Mass mismatch in {label} model.", ) assert_allclose( exported_link.inertial.inertia.matrix(), ref_link.inertial.inertia.matrix(), atol=1e-4, err_msg=f"Inertia matrix mismatch in {label} model.", ) def compare_collisions(exported_link, ref_link, label=""): geom_types = ["box", "sphere", "cylinder"] for geom_type in geom_types: exp_geom = getattr(exported_link.collision.geometry, geom_type) ref_geom = getattr(ref_link.collision.geometry, geom_type) if ref_geom is not None: if geom_type == "box": assert_allclose( jnp.array(exp_geom.size), jnp.array(ref_geom.size), atol=1e-6 ) elif geom_type == "sphere": assert_allclose(exp_geom.radius, 
ref_geom.radius, atol=1e-6) elif geom_type == "cylinder": assert_allclose(exp_geom.radius, ref_geom.radius, atol=1e-6) assert_allclose(exp_geom.length, ref_geom.length, atol=1e-6) return pytest.skip( f"Collision geometry type for link {exported_link.name} not supported." ) def validate_model(updated_model, ref_model, label): urdf = updated_model.export_updated_model() assert isinstance(urdf, str), f"{label}: Exported URDF is not a string." exported_sdf = rod.Sdf.load(urdf, is_urdf=True) assert ( len(exported_sdf.models()) == 1 ), f"{label}: Exported model does not contain exactly one ROD model." exported_model = exported_sdf.models()[0] for link_name in model.link_names(): exported_link = get_link_by_name(exported_model, link_name) ref_link = get_link_by_name(ref_model, link_name) compare_geometries(exported_link, ref_link, label=label) compare_mass_and_inertia(exported_link, ref_link, label=label) compare_collisions(exported_link, ref_link, label=label) # Test both scaled and identity-scaled updates for scaling, label in ( (scaling_parameters, "SCALED"), (identity_scaling, "IDENTITY SCALED"), ): # Load reference ROD model if label == "IDENTITY SCALED": ref_model = rod.Sdf.load(jaxsim_model_garpez.built_from).models()[0] else: ref_model = rod.Sdf.load(jaxsim_model_garpez_scaled.built_from).models()[0] updated_model = js.model.update_hw_parameters(model, scaling) validate_model(updated_model, ref_model, label) def test_hw_parameters_optimization(jaxsim_model_garpez: js.model.JaxSimModel): """ Test that updating hardware parameters allows optimizing the position of a link to match a target value along a specific world axis. """ model = jaxsim_model_garpez data = js.data.JaxSimModelData.build(model=model) # Define the target height for the link. target_height = 3.0 # Get the index of the link to optimize (e.g., "torso"). link_idx = js.link.name_to_idx(model, link_name="link4") # Define the initial hardware parameters (scaling factors). 
initial_dims = jnp.ones( (model.number_of_links(), 3) ) # Initial dimensions (1.0 for all links). initial_density = jnp.ones( (model.number_of_links(),) ) # Initial density (1.0 for all links). scaling_factors = js.kin_dyn_parameters.ScalingFactors( dims=initial_dims, density=initial_density ) # Define the loss function. def loss(scaling_factors): # Update the model with the new hardware parameters. updated_model = js.model.update_hw_parameters( model=model, scaling_factors=scaling_factors ) # Sync the data's cached kinematics (joint transforms, link transforms, …) # with the updated model geometry before running any dynamics. updated_data = data.replace(model=updated_model) # Compute forward kinematics for the link. W_H_L = js.model.forward_kinematics(model=updated_model, data=updated_data)[ link_idx ] # Extract the height (z-axis position) of the link. link4_height = W_H_L[2, 3] # Assuming z-axis is the third row. # Compute the loss as the squared difference from the target height. return (link4_height - target_height) ** 2 # Compute the gradient of the loss function with respect to the scaling factors. loss_grad = jax.grad(loss) # Perform gradient descent. alpha = 0.01 # Learning rate. num_iterations = 1000 # Number of gradient descent steps. for _ in range(num_iterations): # Compute the gradient. grad_scaling_factors = loss_grad(scaling_factors) # Update the scaling factors. scaling_factors = js.kin_dyn_parameters.ScalingFactors( dims=scaling_factors.dims - alpha * grad_scaling_factors.dims, density=scaling_factors.density - alpha * grad_scaling_factors.density, ) # Compute the current loss value. current_loss = loss(scaling_factors) # Optionally, print the progress. if _ % 100 == 0: print(f"Iteration {_}: Loss = {current_loss}") # Assert that the final loss is close to zero. assert current_loss < 1e-3, "Optimization did not converge to the target height." 
def test_hw_parameters_collision_scaling( jaxsim_model_box: js.model.JaxSimModel, prng_key: jax.Array ): """ Test that the collision elements of the model are updated correctly during the scaling of the model hw parameters. """ _, subkey = jax.random.split(prng_key, num=2) # TODO: the jaxsim_model_box has an additional frame, which is handled wrongly # during the export of the updated model. For this reason, we recreate the model # from scratch here. del jaxsim_model_box import rod.builder.primitives # Create on-the-fly a ROD model of a box. rod_model = ( rod.builder.primitives.BoxBuilder(x=0.3, y=0.2, z=0.1, mass=1.0, name="box") .build_model() .add_link(name="box_link") .add_inertial() .add_visual() .add_collision() .build() ) model = js.model.JaxSimModel.build_from_model_description( model_description=rod_model ) # Define the scaling factor for the model scaling_factor = 5.0 # Recompute K and D, since the mass is scaled by scaling_factor^3 # and the expected static compression of the terrain is approximately # proportional to mass/K and divided by the 4 contact points. K = model.contact_params.K * (scaling_factor**2) # Strongly overdamped, to avoid oscillations due to the high mass # and the low penetration allowed. 
D = 8 * jnp.sqrt(K) with model.editable(validate=False) as model: model.contact_params = SoftContactsParams(K=K, D=D) # Define the nominal radius of the sphere nominal_height = model.kin_dyn_parameters.hw_link_metadata.geometry[0, 2] # Define scaling parameters scaling_parameters = ScalingFactors( dims=jnp.ones((model.number_of_links(), 3)) * scaling_factor, density=jnp.array([1.0]), ) # Update the model with the scaling parameters updated_model = js.model.update_hw_parameters(model, scaling_parameters) # Compute the expected height (nominal radius * scaling factor) expected_height = nominal_height * scaling_factor / 2 # Simulate the box falling under gravity data = js.data.JaxSimModelData.build( model=updated_model, # Set the initial position of the box's base to be slightly above the ground # to allow it to settle at the expected height after scaling. # The base position is set to the nominal height of the box scaled by the scaling factor, # plus a small offset to avoid immediate collision with the ground. # This ensures that the box has enough space to fall and settle at the expected height. base_position=jnp.array( [ *jax.random.uniform(subkey, shape=(2,)), expected_height + 0.05, ] ), ) num_steps = 1000 # Number of simulation steps for _ in range(num_steps): data = js.model.step( model=updated_model, data=data, ) # Get the final height of the box's base updated_base_height = data.base_position[2] # Assert that the box settles at the expected height assert jnp.isclose( updated_base_height, expected_height, atol=1e-3 ), f"model base height mismatch: expected {expected_height}, got {updated_base_height}" def test_unsupported_link_cases(): """ Test that unsupported link cases are handled correctly. 
""" import rod.builder.primitives from jaxsim.api.kin_dyn_parameters import LinkParametrizableShape # Test unsupported (no visual) no_visual_model = js.model.JaxSimModel.build_from_model_description( rod.builder.primitives.BoxBuilder(x=1, y=1, z=1, mass=1, name="no_vis_box") .build_model() .add_link(name="no_visual_link") .add_inertial() .build() # No .add_visual() ) no_visual_metadata = no_visual_model.kin_dyn_parameters.hw_link_metadata empty_metadata = HwLinkMetadata.empty() comparison = jax.tree.map( jnp.allclose, no_visual_metadata, empty_metadata, ) assert jax.tree.reduce( lambda acc, value: acc and bool(value), comparison, True ), "No links should be supported." # Create a simple multi-link URDF and add collision to ensure links are kept multi_link_urdf = """ """ # Build JaxSim model from the URDF multi_link_model = js.model.JaxSimModel.build_from_model_description( multi_link_urdf, is_urdf=True ) multi_link_metadata = multi_link_model.kin_dyn_parameters.hw_link_metadata # Verify array consistency for the model num_links = multi_link_model.number_of_links() assert num_links == 3, f"Expected 3 links in the URDF model, got {num_links}" assert ( len(multi_link_metadata.link_shape) == len(multi_link_metadata.geometry) == len(multi_link_metadata.density) == num_links ) # Count verification in single model supported_count = sum( 1 for s in multi_link_metadata.link_shape if s != LinkParametrizableShape.Unsupported ) unsupported_count = sum( 1 for s in multi_link_metadata.link_shape if s == LinkParametrizableShape.Unsupported ) assert ( supported_count == 2 ), f"Expected 2 supported links in single model, got {supported_count}" assert ( unsupported_count == 1 ), f"Expected 1 unsupported link in single model, got {unsupported_count}" # Ensure shapes match expectations by name link_indices = {name: idx for idx, name in enumerate(multi_link_model.link_names())} assert ( multi_link_metadata.link_shape[link_indices["supported_link"]] == LinkParametrizableShape.Box ), 
"Supported link should remain a box" assert ( multi_link_metadata.link_shape[link_indices["unsupported_link"]] == LinkParametrizableShape.Unsupported ), "Unsupported link should remain unsupported" double_visual_idx = link_indices["double_visual_link"] assert ( multi_link_metadata.link_shape[double_visual_idx] == LinkParametrizableShape.Sphere ), "Double visual link should pick the first (sphere) visual" assert_allclose( multi_link_metadata.geometry[double_visual_idx, 0], 0.4, err_msg="Sphere radius must match the first visual", ) # Test selective parametrization: only 'supported_link' and 'double_visual_link' should be parametrized selective_model = js.model.JaxSimModel.build_from_model_description( multi_link_urdf, is_urdf=True, parametrized_links=("double_visual_link") ) selective_metadata = selective_model.kin_dyn_parameters.hw_link_metadata # Check that only the selected links are parametrized link_indices = {name: idx for idx, name in enumerate(selective_model.link_names())} assert ( selective_metadata.link_shape[link_indices["supported_link"]] == LinkParametrizableShape.Unsupported ), "Selected supported_link should be parametrized as Box" assert ( selective_metadata.link_shape[link_indices["double_visual_link"]] == LinkParametrizableShape.Sphere ), "Selected double_visual_link should be parametrized as Sphere" assert ( selective_metadata.link_shape[link_indices["unsupported_link"]] == LinkParametrizableShape.Unsupported ), "Non-selected unsupported_link should be marked as Unsupported" def test_export_continuous_joint_handling(): """ Test that continuous joints are correctly exported with type="continuous" and without position limits, while preserving effort and velocity limits. 
""" # Load cartpole model which has a continuous joint (pivot) cartpole_path = ( pathlib.Path(__file__).parent.parent / "examples" / "assets" / "cartpole.urdf" ) model = js.model.JaxSimModel.build_from_model_description(cartpole_path) # Define some simple scaling parameters (identity scaling) scaling_parameters = ScalingFactors( dims=jnp.ones((model.number_of_links(), 3)), density=jnp.ones(model.number_of_links()), ) # Update the model with scaling parameters updated_model = js.model.update_hw_parameters(model, scaling_parameters) # Export the updated model exported_urdf = updated_model.export_updated_model() # Parse the URDF XML directly (not through rod, which would convert continuous back to revolute) root = ET.fromstring(exported_urdf) # Find the pivot joint (continuous joint) pivot_joint = None for joint_elem in root.findall(".//joint"): if joint_elem.get("name") == "pivot": pivot_joint = joint_elem break assert pivot_joint is not None, "pivot joint should exist in exported model" # Verify that the joint type is "continuous" assert ( pivot_joint.get("type") == "continuous" ), f"pivot joint should have type='continuous', got '{pivot_joint.get('type')}'" # Verify that position limits are not present for continuous joints limit_elem = pivot_joint.find("limit") assert limit_elem is not None, "pivot joint should have limits element" assert ( limit_elem.get("lower") is None ), f"continuous joint should not have lower position limit, got {limit_elem.get('lower')}" assert ( limit_elem.get("upper") is None ), f"continuous joint should not have upper position limit, got {limit_elem.get('upper')}" # Verify that effort and velocity limits are preserved assert ( limit_elem.get("effort") is not None ), "continuous joint should preserve effort limit" assert ( limit_elem.get("velocity") is not None ), "continuous joint should preserve velocity limit" # Verify that the linear joint (prismatic) is NOT changed to continuous linear_joint = None for joint_elem in 
root.findall(".//joint"): if joint_elem.get("name") == "linear": linear_joint = joint_elem break assert linear_joint is not None, "linear joint should exist in exported model" assert ( linear_joint.get("type") == "prismatic" ), f"linear joint should remain prismatic, got '{linear_joint.get('type')}'" # Prismatic joint should keep its limits linear_limit = linear_joint.find("limit") assert linear_limit is not None, "prismatic joint should have limits" assert ( linear_limit.get("lower") is not None ), "prismatic joint should have lower limit" assert ( linear_limit.get("upper") is not None ), "prismatic joint should have upper limit" def test_export_model_with_missing_collision( jaxsim_model_missing_collision: js.model.JaxSimModel, ): """ Test that export_updated_model() works correctly when a link has a visual but is missing a collision element. This validates the skip logic that handles None collision elements. """ model = jaxsim_model_missing_collision # Define scaling parameters to modify the model scaling_parameters = ScalingFactors( dims=jnp.array( [ [1.5, 1.0, 1.0], # Scale x-dimension for link1 (has collision) [2.0, 1.0, 1.0], # Scale radius for link2 (missing collision) ] ), density=jnp.ones(2), ) # Update the model with scaling parameters updated_model = js.model.update_hw_parameters(model, scaling_parameters) # Export the updated model - this should NOT fail even though link2 is missing collision exported_urdf = updated_model.export_updated_model() # Verify basic structure of exported URDF assert isinstance(exported_urdf, str), "Exported URDF should be a string" assert len(exported_urdf) > 0, "Exported URDF should not be empty" # Parse the exported URDF to verify it's valid XML root = ET.fromstring(exported_urdf) assert root.tag == "robot", "Root element should be 'robot'" # Find both links in the exported model links = {link.get("name"): link for link in root.findall(".//link")} assert "link1" in links, "link1 should exist in exported model" assert "link2" 
in links, "link2 should exist in exported model" # Verify link1 has both visual and collision link1 = links["link1"] link1_visual = link1.find("visual") link1_collision = link1.find("collision") assert link1_visual is not None, "link1 should have a visual element" assert link1_collision is not None, "link1 should have a collision element" # Verify link1's geometry was updated link1_visual_box = link1_visual.find(".//box") assert link1_visual_box is not None, "link1 visual should have box geometry" link1_size = [float(x) for x in link1_visual_box.get("size").split()] # First dimension should be scaled by 1.5 assert_allclose( link1_size[0], 0.3 * 1.5, atol=1e-6, err_msg="link1 x-dimension should be scaled", ) # Verify link2 has visual but no collision link2 = links["link2"] link2_visual = link2.find("visual") link2_collision = link2.find("collision") assert link2_visual is not None, "link2 should have a visual element" assert link2_collision is None, "link2 should NOT have a collision element" # Verify link2's visual geometry was updated despite missing collision link2_visual_sphere = link2_visual.find(".//sphere") assert link2_visual_sphere is not None, "link2 visual should have sphere geometry" link2_radius = float(link2_visual_sphere.get("radius")) # Radius should be scaled by 2.0 assert_allclose( link2_radius, 0.1 * 2.0, atol=1e-6, err_msg="link2 radius should be scaled despite missing collision", ) # Load the exported model to verify it can be parsed correctly exported_sdf = rod.Sdf.load(exported_urdf, is_urdf=True) assert ( len(exported_sdf.models()) == 1 ), "Exported model should contain exactly one ROD model" exported_model = exported_sdf.models()[0] assert exported_model.name == model.name(), "Exported model name should match" # Verify we can build a JaxSim model from the exported URDF _ = js.model.JaxSimModel.build_from_model_description( model_description=exported_urdf, is_urdf=True ) def test_export_mesh_scaling_preserves_nonzero_visual_and_joint_origins( 
tmp_path: pathlib.Path, ): """ Regression test for mesh export: non-identity scaling must preserve non-zero visual/joint origins in the URDF. """ mesh_file = pathlib.Path(__file__).parent / "assets" / "cube.stl" if not mesh_file.exists(): pytest.skip(f"Test mesh file not found: {mesh_file}") urdf_path = tmp_path / "mesh_origin_regression.urdf" urdf_path.write_text( f""" """, encoding="utf-8", ) model = js.model.JaxSimModel.build_from_model_description( model_description=urdf_path, is_urdf=True, parametrized_links=("mesh_link",), ) mesh_link_idx = js.link.name_to_idx(model=model, link_name="mesh_link") dims = jnp.ones((model.number_of_links(), 3)) dims = dims.at[mesh_link_idx].set(jnp.array([1.7, 0.8, 1.3])) scaling = ScalingFactors(dims=dims, density=jnp.ones(model.number_of_links())) updated_model = js.model.update_hw_parameters(model=model, scaling_factors=scaling) exported_urdf = updated_model.export_updated_model() root = ET.fromstring(exported_urdf) visual_origin = root.find(".//link[@name='mesh_link']/visual/origin") assert visual_origin is not None, "Mesh visual origin must exist in exported URDF" visual_xyz = np.array([float(v) for v in visual_origin.get("xyz").split()]) expected_visual_xyz = np.array( updated_model.kin_dyn_parameters.hw_link_metadata.L_H_vis[mesh_link_idx][:3, 3] ) assert_allclose( visual_xyz, expected_visual_xyz, atol=1e-8, err_msg="Exported mesh visual origin does not match updated metadata", ) assert not np.allclose(visual_xyz, np.zeros(3), atol=1e-12) joint_origin = root.find(".//joint[@name='base_to_mesh']/origin") assert joint_origin is not None, "Joint origin must exist in exported URDF" joint_xyz = np.array([float(v) for v in joint_origin.get("xyz").split()]) joint_idx = js.joint.name_to_idx(model=updated_model, joint_name="base_to_mesh") expected_joint_xyz = np.array( updated_model.kin_dyn_parameters.joint_model.λ_H_pre[joint_idx + 1][:3, 3] ) assert_allclose( joint_xyz, expected_joint_xyz, atol=1e-8, err_msg="Exported joint 
origin does not match updated joint transform", ) assert not np.allclose(joint_xyz, np.zeros(3), atol=1e-12) reimported_jaxsim_model = js.model.JaxSimModel.build_from_model_description( model_description=exported_urdf, is_urdf=True ) assert ( reimported_jaxsim_model is not None ), "Should be able to build model from exported URDF" assert ( reimported_jaxsim_model.number_of_links() == model.number_of_links() ), "Reimported model should have same number of links" # ============================================================================= # Mesh Scaling Tests # ============================================================================= def test_mesh_shape_enum(): """Test that the Mesh shape type is available in the enum.""" assert hasattr(LinkParametrizableShape, "Mesh") assert LinkParametrizableShape.Mesh == 3 def test_mixed_shapes_metadata(): """Test loading and metadata verification for mixed primitive and mesh shapes.""" test_urdf = pathlib.Path(__file__).parent / "assets" / "mixed_shapes_robot.urdf" if not test_urdf.exists(): pytest.skip(f"Test URDF not found: {test_urdf}") mesh_file = pathlib.Path(__file__).parent / "assets" / "cube.stl" if not mesh_file.exists(): pytest.skip(f"Test mesh not found: {mesh_file}") # Load model with all link types parametrized model = js.model.JaxSimModel.build_from_model_description( model_description=test_urdf, is_urdf=True, parametrized_links=("box_link", "cylinder_link", "mesh_link", "sphere_link"), ) assert model.name() == "mixed_shapes_robot" assert model.number_of_links() == 4 hw_meta = model.kin_dyn_parameters.hw_link_metadata # Verify all 4 links are parametrized with correct shape types assert len(hw_meta.link_shape) == 4 assert hw_meta.link_shape[0] == LinkParametrizableShape.Box assert hw_meta.link_shape[1] == LinkParametrizableShape.Cylinder assert hw_meta.link_shape[2] == LinkParametrizableShape.Mesh assert hw_meta.link_shape[3] == LinkParametrizableShape.Sphere # Verify mesh data exists only for mesh link 
assert hw_meta.mesh_vertices is not None assert hw_meta.mesh_vertices[0] is None # box assert hw_meta.mesh_vertices[1] is None # cylinder assert hw_meta.mesh_vertices[2] is not None # mesh assert hw_meta.mesh_vertices[3] is None # sphere assert hw_meta.mesh_faces is not None assert hw_meta.mesh_faces[2] is not None # mesh link has faces def test_mixed_shapes_scaling(): """Test uniform and non-uniform scaling with mixed primitive and mesh shapes.""" test_urdf = pathlib.Path(__file__).parent / "assets" / "mixed_shapes_robot.urdf" if not test_urdf.exists(): pytest.skip(f"Test URDF not found: {test_urdf}") mesh_file = pathlib.Path(__file__).parent / "assets" / "cube.stl" if not mesh_file.exists(): pytest.skip(f"Test mesh not found: {mesh_file}") model = js.model.JaxSimModel.build_from_model_description( model_description=test_urdf, is_urdf=True, parametrized_links=("box_link", "cylinder_link", "mesh_link", "sphere_link"), ) hw_meta = model.kin_dyn_parameters.hw_link_metadata if len(hw_meta.link_shape) == 0: pytest.skip("Hardware parametrization not supported") # Get original masses masses_orig = {} for i in range(model.number_of_links()): link_name = js.link.idx_to_name(model=model, link_index=i) masses_orig[link_name] = float(model.kin_dyn_parameters.link_parameters.mass[i]) # Test uniform scaling (2x), so all links should scaled by 8x uniform_scaling = ScalingFactors( dims=jnp.ones((4, 3)) * 2.0, density=jnp.ones(4), ) scaled_uniform = js.model.update_hw_parameters(model, uniform_scaling) for i in range(scaled_uniform.number_of_links()): link_name = js.link.idx_to_name(model=scaled_uniform, link_index=i) mass_scaled = float(scaled_uniform.kin_dyn_parameters.link_parameters.mass[i]) ratio = mass_scaled / masses_orig[link_name] assert jnp.allclose( ratio, 8.0, rtol=0.1 ), f"Uniform scaling: {link_name} expected 8x, got {ratio:.2f}x" # Test different scaling factors per link different_scaling = ScalingFactors( dims=jnp.array( [ [2.0, 2.0, 2.0], # box: 8x [3.0, 3.0, 
3.0], # cylinder: 27x [1.5, 1.5, 1.5], # mesh: 3.375x [2.5, 2.5, 2.5], # sphere: 15.625x ] ), density=jnp.ones(4), ) scaled_different = js.model.update_hw_parameters(model, different_scaling) expected_ratios = { "box_link": 8.0, "cylinder_link": 27.0, "mesh_link": 3.375, "sphere_link": 15.625, } for i in range(scaled_different.number_of_links()): link_name = js.link.idx_to_name(model=scaled_different, link_index=i) mass_scaled = float(scaled_different.kin_dyn_parameters.link_parameters.mass[i]) ratio = mass_scaled / masses_orig[link_name] expected = expected_ratios[link_name] assert jnp.allclose( ratio, expected, rtol=0.1 ), f"Different scaling: {link_name} expected {expected}x, got {ratio:.2f}x" ================================================ FILE: tests/test_automatic_differentiation.py ================================================ import os import jax import jax.numpy as jnp import numpy as np from jax.test_util import check_grads import jaxsim.api as js import jaxsim.math import jaxsim.rbda import jaxsim.typing as jtp from jaxsim import VelRepr from jaxsim.rbda.contacts import SoftContacts, SoftContactsParams from .utils import assert_allclose # All JaxSim algorithms, excluding the variable-step integrators, should support # being automatically differentiated until second order, both in FWD and REV modes. # However, checking the second-order derivatives is particularly slow and makes # CI tests take too long. Therefore, we only check first-order derivatives. AD_ORDER = os.environ.get("JAXSIM_TEST_AD_ORDER", 1) # Define the step size used to compute finite differences depending on the # floating point resolution. 
ε = os.environ.get( "JAXSIM_TEST_FD_STEP_SIZE", jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3), ) def get_random_data_and_references( model: js.model.JaxSimModel, velocity_representation: VelRepr, key: jax.Array, ) -> tuple[js.data.JaxSimModelData, js.references.JaxSimModelReferences]: key, subkey = jax.random.split(key, num=2) data = js.data.random_model_data( model=model, key=subkey, velocity_representation=velocity_representation ) _, subkey1, subkey2 = jax.random.split(key, num=3) references = js.references.JaxSimModelReferences.build( model=model, joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)), link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)), data=data, velocity_representation=velocity_representation, ) # Remove the force applied to the base link if the model is fixed-base. if not model.floating_base(): references = references.apply_link_forces( forces=jnp.atleast_2d(jnp.zeros(6)), model=model, data=data, link_names=(model.base_link(),), additive=False, ) return data, references def test_ad_aba( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data, references = get_random_data_and_references( model=model, velocity_representation=VelRepr.Inertial, key=subkey ) # Get the standard gravity constant. g = jaxsim.math.STANDARD_GRAVITY # State in VelRepr.Inertial representation. W_p_B = data.base_position W_Q_B = data.base_orientation s = data.joint_positions W_v_WB = data.base_velocity ṡ = data.joint_velocities # Inputs. W_f_L = references.link_forces(model=model) τ = references.joint_force_references(model=model) # ==== # Test # ==== # Get a closure exposing only the parameters to be differentiated. 
def aba(W_p_B, W_Q_B, s, W_v_WB, ṡ, τ, W_f_L, g): import jaxlie W_H_B = jaxlie.SE3.from_rotation_and_translation( rotation=jaxlie.SO3(wxyz=W_Q_B / jnp.linalg.norm(W_Q_B)), translation=W_p_B, ).as_matrix() joint_transforms = model.kin_dyn_parameters.joint_transforms( joint_positions=s, base_transform=W_H_B ) return jaxsim.rbda.aba( model=model, base_position=W_p_B, base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B), joint_positions=s, base_linear_velocity=W_v_WB[0:3], base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, joint_transforms=joint_transforms, joint_forces=τ, link_forces=W_f_L, standard_gravity=g, ) # Check derivatives against finite differences. check_grads( f=aba, args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, τ, W_f_L, g), order=AD_ORDER, modes=["rev", "fwd"], eps=ε, ) def test_ad_aba_parallel( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data, references = get_random_data_and_references( model=model, velocity_representation=VelRepr.Inertial, key=subkey ) g = jaxsim.math.STANDARD_GRAVITY W_p_B = data.base_position W_Q_B = data.base_orientation s = data.joint_positions W_v_WB = data.base_velocity ṡ = data.joint_velocities i_X_λi = data._joint_transforms W_f_L = references.link_forces(model=model) τ = references.joint_force_references(model=model) # Verify parallel ABA matches sequential ABA. 
result_seq = jaxsim.rbda.aba( model=model, base_position=W_p_B, base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B), joint_positions=s, base_linear_velocity=W_v_WB[0:3], base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, joint_transforms=i_X_λi, joint_forces=τ, link_forces=W_f_L, standard_gravity=g, ) result_par = jaxsim.rbda.aba_parallel( model=model, base_position=W_p_B, base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B), joint_positions=s, base_linear_velocity=W_v_WB[0:3], base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, joint_forces=τ, joint_transforms=i_X_λi, link_forces=W_f_L, standard_gravity=g, ) assert_allclose(result_seq[0], result_par[0], atol=1e-10) assert_allclose(result_seq[1], result_par[1], atol=1e-10) # Check derivatives against finite differences. aba_par = lambda W_p_B, W_Q_B, s, W_v_WB, ṡ, i_X_λi, τ, W_f_L, g: jaxsim.rbda.aba_parallel( model=model, base_position=W_p_B, base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B), joint_positions=s, base_linear_velocity=W_v_WB[0:3], base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, joint_forces=τ, joint_transforms=i_X_λi, link_forces=W_f_L, standard_gravity=g, ) check_grads( f=aba_par, args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, i_X_λi, τ, W_f_L, g), order=AD_ORDER, modes=["rev", "fwd"], eps=ε, ) def test_ad_rnea( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, ): model = jaxsim_models_types key, subkey = jax.random.split(prng_key, num=2) data, references = get_random_data_and_references( model=model, velocity_representation=VelRepr.Inertial, key=subkey ) # Get the standard gravity constant. g = jaxsim.math.STANDARD_GRAVITY # State in VelRepr.Inertial representation. W_p_B = data.base_position W_Q_B = data.base_orientation s = data.joint_positions W_v_WB = data.base_velocity ṡ = data.joint_velocities i_X_λi = data._joint_transforms # Inputs. 
W_f_L = references.link_forces(model=model) # ==== # Test # ==== _, subkey1, subkey2 = jax.random.split(key, num=3) W_v̇_WB = jax.random.uniform(subkey1, shape=(6,), minval=-1) s̈ = jax.random.uniform(subkey2, shape=(model.dofs(),), minval=-1) # Get a closure exposing only the parameters to be differentiated. rnea = lambda W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, i_X_λi, W_f_L, g: jaxsim.rbda.rnea( model=model, base_position=W_p_B, base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B), joint_positions=s, base_linear_velocity=W_v_WB[0:3], base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, base_linear_acceleration=W_v̇_WB[0:3], base_angular_acceleration=W_v̇_WB[3:6], joint_accelerations=s̈, joint_transforms=i_X_λi, link_forces=W_f_L, standard_gravity=g, ) # Check derivatives against finite differences. check_grads( f=rnea, args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, i_X_λi, W_f_L, g), order=AD_ORDER, modes=["rev", "fwd"], eps=ε, ) def test_ad_crba( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data, _ = get_random_data_and_references( model=model, velocity_representation=VelRepr.Inertial, key=subkey ) # State in VelRepr.Inertial representation. s = data.joint_positions # ==== # Test # ==== # Get a closure exposing only the parameters to be differentiated. crba = lambda s: jaxsim.rbda.crba(model=model, joint_positions=s) # Check derivatives against finite differences. check_grads( f=crba, args=(s,), order=AD_ORDER, modes=["rev", "fwd"], eps=ε, ) def test_ad_fk( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data, _ = get_random_data_and_references( model=model, velocity_representation=VelRepr.Inertial, key=subkey ) # State in VelRepr.Inertial representation. 
W_p_B = data.base_position W_Q_B = data.base_orientation s = data.joint_positions W_v_lin = data._base_linear_velocity W_v_ang = data._base_angular_velocity ṡ = data.joint_velocities # ==== # Test # ==== # Get a closure exposing only the parameters to be differentiated. def fk(W_p_B, W_Q_B, s, W_v_lin, W_v_ang, ṡ): import jaxlie W_H_B = jaxlie.SE3.from_rotation_and_translation( rotation=jaxlie.SO3(wxyz=W_Q_B / jnp.linalg.norm(W_Q_B)), translation=W_p_B, ).as_matrix() joint_transforms = model.kin_dyn_parameters.joint_transforms( joint_positions=s, base_transform=W_H_B ) return jaxsim.rbda.forward_kinematics_model( model=model, base_position=W_p_B, base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B), joint_positions=s, base_linear_velocity_inertial=W_v_lin, base_angular_velocity_inertial=W_v_ang, joint_velocities=ṡ, joint_transforms=joint_transforms, ) # Check derivatives against finite differences. check_grads( f=fk, args=(W_p_B, W_Q_B, s, W_v_lin, W_v_ang, ṡ), order=AD_ORDER, modes=["rev", "fwd"], eps=ε, ) def test_ad_jacobian( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data, _ = get_random_data_and_references( model=model, velocity_representation=VelRepr.Inertial, key=subkey ) # State in VelRepr.Inertial representation. s = data.joint_positions # ==== # Test # ==== # Get the link indices. link_indices = jnp.arange(model.number_of_links()) # Get a closure exposing only the parameters to be differentiated. # We differentiate the jacobian of the last link, likely among those # farther from the base. jacobian = lambda s: jaxsim.rbda.jacobian( model=model, joint_positions=s, link_index=link_indices[-1] ) # Check derivatives against finite differences. 
check_grads( f=jacobian, args=(s,), order=AD_ORDER, modes=["rev", "fwd"], eps=ε, ) def test_ad_soft_contacts( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, ): model = jaxsim_models_types with model.editable(validate=False) as model: model.contact_model = jaxsim.rbda.contacts.SoftContacts.build() _, subkey1, subkey2, subkey3 = jax.random.split(prng_key, num=4) p = jax.random.uniform(subkey1, shape=(3,), minval=-1) v = jax.random.uniform(subkey2, shape=(3,), minval=-1) m = jax.random.uniform(subkey3, shape=(3,), minval=-1) # Get the soft contacts parameters. parameters = js.contact.estimate_good_contact_parameters(model=model) # ==== # Test # ==== # Get a closure exposing only the parameters to be differentiated. def close_over_inputs_and_parameters( p: jtp.VectorLike, v: jtp.VectorLike, m: jtp.VectorLike, params: SoftContactsParams, ) -> tuple[jtp.Vector, jtp.Vector]: W_f_Ci, CW_ṁ = SoftContacts.compute_contact_force( position=p, velocity=v, tangential_deformation=m, parameters=params, terrain=model.terrain, ) return W_f_Ci, CW_ṁ # Check derivatives against finite differences. check_grads( f=close_over_inputs_and_parameters, args=(p, v, m, parameters), order=AD_ORDER, modes=["rev", "fwd"], eps=ε, # On GPU, the tolerance needs to be increased. rtol=0.02 if "gpu" in {d.platform for d in p.devices()} else None, ) def test_ad_integration( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, ): model = jaxsim_models_types _, subkey = jax.random.split(prng_key, num=2) data, references = get_random_data_and_references( model=model, velocity_representation=VelRepr.Inertial, key=subkey ) # State in VelRepr.Inertial representation. W_p_B = data.base_position W_Q_B = data.base_orientation s = data.joint_positions W_v_WB = data.base_velocity ṡ = data.joint_velocities # Inputs. 
W_f_L = references.link_forces(model=model) τ = references.joint_force_references(model=model) # ==== # Test # ==== # Function exposing only the parameters to be differentiated. def step( W_p_B: jtp.Vector, W_Q_B: jtp.Vector, s: jtp.Vector, W_v_WB: jtp.Vector, ṡ: jtp.Vector, τ: jtp.Vector, W_f_L: jtp.Matrix, ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: # When JAX tests against finite differences, the injected ε will make the # quaternion non-unitary, which will cause the AD check to fail. W_Q_B = W_Q_B / jnp.linalg.norm(W_Q_B) data_x0 = data.replace( model=model, base_position=W_p_B, base_quaternion=W_Q_B, joint_positions=s, base_linear_velocity=W_v_WB[0:3], base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, ) data_xf = js.model.step( model=model, data=data_x0, joint_force_references=τ, link_forces=W_f_L, ) xf_W_p_B = data_xf.base_position xf_W_Q_B = data_xf.base_orientation xf_s = data_xf.joint_positions xf_W_v_WB = data_xf.base_velocity xf_ṡ = data_xf.joint_velocities return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ # Check derivatives against finite differences. check_grads( f=step, args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, τ, W_f_L), order=AD_ORDER, modes=["fwd", "rev"], eps=ε, ) def test_ad_safe_norm( prng_key: jax.Array, ): _, subkey = jax.random.split(prng_key, num=2) array = jax.random.uniform(subkey, shape=(4,), minval=-5, maxval=5) # ==== # Test # ==== # Test that the safe_norm function is compatible with batching. array = jnp.stack([array, array]) assert jaxsim.math.safe_norm(array, axis=-1).shape == (2,) # Test that the safe_norm function is correctly computing the norm. assert_allclose( jaxsim.math.safe_norm(array, axis=-1), np.linalg.norm(array, axis=-1) ) # Function exposing only the parameters to be differentiated. def safe_norm(array: jtp.Array) -> jtp.Array: return jaxsim.math.safe_norm(array, axis=-1) # Check derivatives against finite differences. 
check_grads( f=safe_norm, args=(array,), order=AD_ORDER, modes=["rev", "fwd"], eps=ε, ) # Check derivatives against finite differences when the array is zero. check_grads( f=safe_norm, args=(jnp.zeros_like(array),), order=AD_ORDER, modes=["rev", "fwd"], eps=ε, ) def test_ad_hw_parameters( jaxsim_model_garpez: js.model.JaxSimModel, prng_key: jax.Array, ): """ Test the automatic differentiation capability for hardware parameters of the model links. """ model = jaxsim_model_garpez data = js.data.JaxSimModelData.build(model=model) min_val, max_val = 0.5, 10.0 # Generate random scaling factors for testing. _, subkey1, subkey2 = jax.random.split(prng_key, num=3) dims_scaling = jax.random.uniform( subkey1, shape=(model.number_of_links(), 3), minval=min_val, maxval=max_val ) density_scaling = jax.random.uniform( subkey2, shape=(model.number_of_links(),), minval=min_val, maxval=max_val ) scaling_factors = js.kin_dyn_parameters.ScalingFactors( dims=dims_scaling, density=density_scaling ) link_idx = js.link.name_to_idx(model, link_name="link4") # Define a function that updates hardware parameters and computes FK for link 4. def update_hw_params_and_compute_fk_and_mass( scaling_factors: js.kin_dyn_parameters.ScalingFactors, ): # Update hardware parameters. updated_model = js.model.update_hw_parameters( model=model, scaling_factors=scaling_factors ) # Compute forward kinematics for link 4. W_H_L4 = js.model.forward_kinematics(model=updated_model, data=data)[link_idx] # Compute the free floating mass matrix of the updated model. M = js.model.free_floating_mass_matrix(updated_model, data) return W_H_L4[:3, 3], M # Check derivatives against finite differences. 
check_grads( f=update_hw_params_and_compute_fk_and_mass, args=(scaling_factors,), order=AD_ORDER, modes=["fwd", "rev"], eps=ε, ) ================================================ FILE: tests/test_benchmark.py ================================================ from collections.abc import Callable import jax import jax.numpy as jnp import pytest import jaxsim import jaxsim.api as js from jaxsim.api.kin_dyn_parameters import ScalingFactors def vectorize_data(model: js.model.JaxSimModel, batch_size: int): key = jax.random.PRNGKey(seed=0) keys = jax.random.split(key, num=batch_size) return jax.vmap( lambda key: js.data.random_model_data( model=model, key=key, ) )(keys) def benchmark_test_function( func: Callable, model: js.model.JaxSimModel, benchmark, batch_size ): """Reusability wrapper for benchmark tests.""" data = vectorize_data(model=model, batch_size=batch_size) # Warm-up call to avoid including compilation time jax.vmap(func, in_axes=(None, 0))(model, data) # Benchmark the function call # Note: jax.block_until_ready is used to ensure that the benchmark is not measuring only the asynchronous dispatch benchmark(jax.block_until_ready(jax.vmap(func, in_axes=(None, 0))), model, data) @pytest.mark.benchmark def test_forward_dynamics_aba( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): model = jaxsim_model_ergocub_reduced benchmark_test_function(js.model.forward_dynamics_aba, model, benchmark, batch_size) @pytest.mark.benchmark def test_free_floating_bias_forces( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): model = jaxsim_model_ergocub_reduced benchmark_test_function( js.model.free_floating_bias_forces, model, benchmark, batch_size ) @pytest.mark.benchmark def test_forward_kinematics( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): model = jaxsim_model_ergocub_reduced benchmark_test_function(js.model.forward_kinematics, model, benchmark, batch_size) @pytest.mark.benchmark def 
test_free_floating_mass_matrix( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): model = jaxsim_model_ergocub_reduced benchmark_test_function( js.model.free_floating_mass_matrix, model, benchmark, batch_size ) @pytest.mark.benchmark def test_free_floating_jacobian( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): model = jaxsim_model_ergocub_reduced benchmark_test_function( js.model.generalized_free_floating_jacobian, model, benchmark, batch_size ) @pytest.mark.benchmark def test_free_floating_jacobian_derivative( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): model = jaxsim_model_ergocub_reduced benchmark_test_function( js.model.generalized_free_floating_jacobian_derivative, model, benchmark, batch_size, ) @pytest.mark.benchmark def test_soft_contact_model( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): model = jaxsim_model_ergocub_reduced with model.editable(validate=False) as model: model.contact_model = jaxsim.rbda.contacts.SoftContacts() model.contact_params = js.contact.estimate_good_contact_parameters(model=model) benchmark_test_function(js.ode.system_dynamics, model, benchmark, batch_size) @pytest.mark.benchmark def test_rigid_contact_model( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): model = jaxsim_model_ergocub_reduced with model.editable(validate=False) as model: model.contact_model = jaxsim.rbda.contacts.RigidContacts() model.contact_params = js.contact.estimate_good_contact_parameters(model=model) benchmark_test_function(js.ode.system_dynamics, model, benchmark, batch_size) @pytest.mark.benchmark def test_relaxed_rigid_contact_model( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): model = jaxsim_model_ergocub_reduced with model.editable(validate=False) as model: model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts() model.contact_params = 
js.contact.estimate_good_contact_parameters(model=model) benchmark_test_function(js.ode.system_dynamics, model, benchmark, batch_size) @pytest.mark.benchmark def test_simulation_step( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size ): model = jaxsim_model_ergocub_reduced with model.editable(validate=False) as model: model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts() model.contact_params = js.contact.estimate_good_contact_parameters(model=model) benchmark_test_function(js.model.step, model, benchmark, batch_size) @pytest.mark.benchmark def test_update_hw_parameters( jaxsim_model_garpez: js.model.JaxSimModel, benchmark, batch_size ): """Benchmark hardware parameter scaling/update operation (vmapped).""" model = jaxsim_model_garpez n_links = model.number_of_links() # Create a function that generates random scaling factors and updates the model def update_with_random_scaling(key: jax.Array) -> js.model.JaxSimModel: # Generate scaling factors in a reasonable range [0.8, 1.2] dims_scale = jax.random.uniform(key, shape=(n_links, 3), minval=0.8, maxval=1.2) density_scale = jax.random.uniform( jax.random.fold_in(key, 1), shape=(n_links,), minval=0.8, maxval=1.2 ) scaling_factors = ScalingFactors(dims=dims_scale, density=density_scale) return js.model.update_hw_parameters(model, scaling_factors) # Generate batch of random keys key = jax.random.PRNGKey(seed=42) keys = jax.random.split(key, num=batch_size) # Warm-up call to avoid including compilation time jax.vmap(update_with_random_scaling)(keys) # Benchmark the vmapped update operation benchmark(jax.block_until_ready(jax.vmap(update_with_random_scaling)), keys) @pytest.mark.benchmark def test_export_updated_model( jaxsim_model_garpez: js.model.JaxSimModel, benchmark, batch_size ): """Benchmark model export after hardware parameter update.""" model = jaxsim_model_garpez n_links = model.number_of_links() # Create multiple scaled models for benchmarking # Each with slightly different 
scaling to simulate realistic scenarios key = jax.random.PRNGKey(seed=42) scaling_values = jax.random.uniform( key, shape=(batch_size,), minval=0.9, maxval=1.2 ) updated_models = [] for scale_value in scaling_values: scaling_factors = ScalingFactors( dims=jnp.ones((n_links, 3)) * float(scale_value), density=jnp.ones(n_links), ) updated_models.append(js.model.update_hw_parameters(model, scaling_factors)) # Benchmark the export operation (sequentially for all models) # Note: This is not JIT-compiled since it returns a string (URDF/SDF) def export_all(): return [m.export_updated_model() for m in updated_models] benchmark(export_all) ================================================ FILE: tests/test_exceptions.py ================================================ import io from contextlib import redirect_stdout import chex import jax import jax.numpy as jnp import pytest from jax.errors import JaxRuntimeError from jaxsim import exceptions def test_exceptions_in_jit_functions(): msg_during_jit = "Compiling jit_compiled_function" @jax.jit @chex.assert_max_traces(n=1) def jit_compiled_function(data: jax.Array) -> jax.Array: # This message is compiled only during JIT compilation. print(msg_during_jit) # Condition that will trigger the exception. failed_if_42_plus = jnp.allclose(data, 42) # Raise a ValueError if the condition is met. # The fmt string is built from kwargs. exceptions.raise_value_error_if( condition=failed_if_42_plus, msg="Raising ValueError since data={num}", num=data, ) # Condition that will trigger the exception. failed_if_42_minus = jnp.allclose(data, -42) # Raise a RuntimeError if the condition is met. # The fmt string is built from args. exceptions.raise_runtime_error_if( failed_if_42_minus, "Raising RuntimeError since data={}", data, ) return data # In the first call, the function will be compiled and print the message. 
with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf): data = 40 out = jit_compiled_function(data=data) stdout = buf.getvalue() assert out == data assert msg_during_jit in stdout # In the second call, the function won't be compiled and won't print the message. with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf): data = 41 out = jit_compiled_function(data=data) stdout = buf.getvalue() assert out == data assert msg_during_jit not in stdout # Let's trigger a ValueError exception by passing 42. data = 42 with pytest.raises( JaxRuntimeError, match=f"ValueError: Raising ValueError since data={data}", ): _ = jit_compiled_function(data=data) # Let's trigger a RuntimeError exception by passing -42. data = -42 with pytest.raises( JaxRuntimeError, match=f"RuntimeError: Raising RuntimeError since data={data}", ): _ = jit_compiled_function(data=data) ================================================ FILE: tests/test_meshes.py ================================================ import trimesh from jaxsim.parsers.rod import meshes def test_mesh_wrapping_vertex_extraction(): """ Test the vertex extraction method on different meshes. 1. A simple box. 2. A sphere. """ # Test 1: A simple box. # First, create a box with origin at (0,0,0) and extents (3,3,3), # i.e. points span from -1.5 to 1.5 on the axis. mesh = trimesh.creation.box( extents=[3.0, 3.0, 3.0], ) points = meshes.extract_points_vertices(mesh=mesh) assert len(points) == len(mesh.vertices) # Test 2: A sphere. # The sphere is centered at the origin and has a radius of 1.0. mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) points = meshes.extract_points_vertices(mesh=mesh) assert len(points) == len(mesh.vertices) def test_mesh_wrapping_aap(): """ Test the AAP wrapping method on different meshes. 1. A simple box 1.1: Remove all points above x=0.0 1.2: Remove all points below y=0.0 2. A sphere """ # Test 1.1: Remove all points above x=0.0. 
# The expected result is that the number of points is halved. # First, create a box with origin at (0,0,0) and extents (3,3,3), # i.e. points span from -1.5 to 1.5 on the axis. mesh = trimesh.creation.box(extents=[3.0, 3.0, 3.0]) points = meshes.extract_points_aap(mesh=mesh, axis="x", lower=0.0) assert len(points) == len(mesh.vertices) // 2 assert all(points[:, 0] > 0.0) # Test 1.2: Remove all points below y=0.0. # The expected result is that the number of points is halved. points = meshes.extract_points_aap(mesh=mesh, axis="y", upper=0.0) assert len(points) == len(mesh.vertices) // 2 assert all(points[:, 1] < 0.0) # Test 2: A sphere. # The sphere is centered at the origin and has a radius of 1.0. # Points are expected to be halved. mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) # Remove all points above y=0.0. points = meshes.extract_points_aap(mesh=mesh, axis="y", lower=0.0) assert all(points[:, 1] >= 0.0) assert len(points) < len(mesh.vertices) def test_mesh_wrapping_points_over_axis(): """ Test the points over axis method on different meshes. 1. A simple box 1.1: Select 10 points from the lower end of the x-axis 1.2: Select 10 points from the higher end of the y-axis 2. A sphere """ # Test 1.1: Remove 10 points from the lower end of the x-axis. # First, create a box with origin at (0,0,0) and extents (3,3,3), # i.e. points span from -1.5 to 1.5 on the axis. mesh = trimesh.creation.box(extents=[3.0, 3.0, 3.0]) points = meshes.extract_points_select_points_over_axis( mesh=mesh, axis="x", direction="lower", n=4 ) assert len(points) == 4 assert all(points[:, 0] < 0.0) # Test 1.2: Select 10 points from the higher end of the y-axis. points = meshes.extract_points_select_points_over_axis( mesh=mesh, axis="y", direction="higher", n=4 ) assert len(points) == 4 assert all(points[:, 1] > 0.0) # Test 2: A sphere. # The sphere is centered at the origin and has a radius of 1.0. 
mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0) sphere_n_vertices = len(mesh.vertices) # Select 10 points from the higher end of the z-axis. points = meshes.extract_points_select_points_over_axis( mesh=mesh, axis="z", direction="higher", n=sphere_n_vertices // 2 ) assert len(points) == sphere_n_vertices // 2 assert all(points[:, 2] >= 0.0) ================================================ FILE: tests/test_pytree.py ================================================ import io import pathlib from contextlib import redirect_stdout import chex import jax import jax.numpy as jnp import pytest import jaxsim.api as js def test_call_jit_compiled_function_passing_different_objects( ergocub_model_description_path: pathlib.Path, jaxsim_model_box ): # Create a first model from the URDF. ergocub_model1 = js.model.JaxSimModel.build_from_model_description( model_description=ergocub_model_description_path ) # Create a second model from the URDF. ergocub_model2 = js.model.JaxSimModel.build_from_model_description( model_description=ergocub_model_description_path ) box_model = jaxsim_model_box # The objects should be different, but the comparison should return True. assert id(ergocub_model1) != id(ergocub_model2) assert ergocub_model1 == ergocub_model2 assert hash(ergocub_model1) == hash(ergocub_model2) # If this function has never been compiled by any other test, JAX will # jit-compile it here. _ = js.contact.estimate_good_contact_parameters(model=ergocub_model1) # Now JAX should not compile it again. with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf): # Beyond running without any JIT recompilations, the following function # should work on different JaxSimModel objects without raising any errors # related to the comparison of Static fields. 
_ = js.contact.estimate_good_contact_parameters(model=ergocub_model2) stdout = buf.getvalue() assert ( f"Compiling {js.contact.estimate_good_contact_parameters.__name__}" not in stdout ) # Define a new JIT-compiled function and check that is not recompiled for # different model objects having the same pytree structure. @jax.jit @chex.assert_max_traces(n=1) def my_jit_function(model: js.model.JaxSimModel, data: js.data.JaxSimModelData): # Return random elements from model and data, just to have something returned. return ( jnp.sum(model.kin_dyn_parameters.link_parameters.mass), data.base_position, ) data1 = js.data.JaxSimModelData.build(model=ergocub_model1) _ = my_jit_function(model=ergocub_model1, data=data1) # This should not retrace the function, as ergocub_model2 has the same # pytree structure as ergocub_model1. _ = my_jit_function(model=ergocub_model2, data=data1) # Calling the function with a different model object will retrace it, as # expected. Therefore, an AssertionError should be raised. with pytest.raises( AssertionError, match="Function 'my_jit_function' is traced > 1 times!" ): data3 = js.data.JaxSimModelData.build(model=box_model) _ = my_jit_function(model=box_model, data=data3) ================================================ FILE: tests/test_simulations.py ================================================ import jax import jax.numpy as jnp import numpy as np import pytest import jaxsim.api as js import jaxsim.rbda import jaxsim.typing as jtp from jaxsim import VelRepr from jaxsim.api.kin_dyn_parameters import ConstraintType from .utils import assert_allclose def test_box_with_external_forces( jaxsim_model_box: js.model.JaxSimModel, velocity_representation: VelRepr, ): """ Simulate a box falling due to gravity. We apply to its CoM a 6D force that balances exactly the gravitational force. The box should not fall. """ model = jaxsim_model_box # Build the data of the model. 
data0 = js.data.JaxSimModelData.build( model=model, base_position=jnp.array([0.0, 0.0, 0.5]), velocity_representation=velocity_representation, ) # Compute the force due to gravity at the CoM. mg = -model.gravity * js.model.total_mass(model=model) G_f = jnp.array([0.0, 0.0, mg, 0, 0, 0]) # Compute the position of the CoM expressed in the coordinates of the link frame L. L_p_CoM = js.link.com_position( model=model, data=data0, link_index=0, in_link_frame=True ) # Compute the transform of 6D forces from the CoM to the link frame. L_H_G = jaxsim.math.Transform.from_quaternion_and_translation(translation=L_p_CoM) G_Xv_L = jaxsim.math.Adjoint.from_transform(transform=L_H_G, inverse=True) L_Xf_G = G_Xv_L.T L_f = L_Xf_G @ G_f # Initialize a references object that simplifies handling external forces. references = js.references.JaxSimModelReferences.build( model=model, data=data0, velocity_representation=velocity_representation, ) # Apply a link forces to the base link. with references.switch_velocity_representation(VelRepr.Body): references = references.apply_link_forces( forces=jnp.atleast_2d(L_f), link_names=model.link_names()[0:1], model=model, data=data0, additive=False, ) # Initialize the simulation horizon. tf = 0.5 T_ns = jnp.arange(start=0, stop=tf * 1e9, step=model.time_step * 1e9, dtype=int) # Copy the initial data... data = data0.copy() # ... and step the simulation. for _ in T_ns: data = js.model.step( model=model, data=data, link_forces=references.link_forces(model, data), ) # Check that the box didn't move. assert_allclose(data.base_position, data0.base_position) assert_allclose(data.base_orientation, data0.base_orientation) def test_box_with_zero_gravity( jaxsim_model_box: js.model.JaxSimModel, velocity_representation: VelRepr, prng_key: jnp.ndarray, ): model = jaxsim_model_box # Move the terrain (almost) infinitely far away from the box. 
with model.editable(validate=False) as model: model.terrain = jaxsim.terrain.FlatTerrain.build(height=-1e9) model.gravity = 0.0 # Split the PRNG key. _, subkey = jax.random.split(prng_key, num=2) # Build the data of the model. data0 = js.data.JaxSimModelData.build( model=model, base_position=jax.random.uniform(subkey, shape=(3,)), velocity_representation=velocity_representation, ) # Initialize a references object that simplifies handling external forces. references = js.references.JaxSimModelReferences.build( model=model, data=data0, velocity_representation=velocity_representation, ) # Apply a link forces to the base link. with references.switch_velocity_representation(jaxsim.VelRepr.Mixed): # Generate a random linear force. # We enforce them to be the same for all velocity representations so that # we can compare their outcomes. LW_f = 10.0 * ( jax.random.uniform(jax.random.key(0), shape=(model.number_of_links(), 6)) .at[:, 3:] .set(jnp.zeros(3)) ) # Note that the context manager does not switch back the newly created # `references` (that is not the yielded object) to the original representation. # In the simulation loop below, we need to make sure that we switch both `data` # and `references` to the same representation before extracting the information # passed to the step function. references = references.apply_link_forces( forces=jnp.atleast_2d(LW_f), link_names=model.link_names(), model=model, data=data0, additive=False, ) tf = 0.01 T = jnp.arange(start=0, stop=tf * 1e9, step=model.time_step * 1e9, dtype=int) # Copy the initial data... data = data0.copy() # ... and step the simulation. for _ in T: with ( data.switch_velocity_representation(velocity_representation), references.switch_velocity_representation(velocity_representation), ): data = js.model.step( model=model, data=data, link_forces=references.link_forces(model=model, data=data), ) # Check that the box moved as expected. 
assert_allclose( data.base_position, data0.base_position + 0.5 * LW_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2, atol=1e-3, ) def run_simulation( model: js.model.JaxSimModel, data_t0: js.data.JaxSimModelData, tf: jtp.FloatLike, ) -> js.data.JaxSimModelData: # Initialize the integration horizon. T_ns = jnp.arange( start=0.0, stop=int(tf * 1e9), step=int(model.time_step * 1e9) ).astype(int) # Initialize the simulation data. data = data_t0.copy() for _ in T_ns: data = js.model.step( model=model, data=data, ) return data def test_simulation_with_soft_contacts( jaxsim_model_box: js.model.JaxSimModel, integrator ): model = jaxsim_model_box # Define the maximum penetration of each collidable point at steady state. max_penetration = 0.001 with model.editable(validate=False) as model: model.contact_model = jaxsim.rbda.contacts.SoftContacts.build() model.contact_params = js.contact.estimate_good_contact_parameters( model=model, number_of_active_collidable_points_steady_state=4, static_friction_coefficient=1.0, damping_ratio=1.0, max_penetration=max_penetration, ) # Enable a subset of the collidable points. enabled_collidable_points_mask = np.zeros( len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool ) enabled_collidable_points_mask[[0, 1, 2, 3]] = True model.kin_dyn_parameters.contact_parameters.enabled = tuple( enabled_collidable_points_mask.tolist() ) assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 # Check jaxsim_model_box@conftest.py. box_height = 0.1 # Build the data of the model. 
data_t0 = js.data.JaxSimModelData.build( model=model, base_position=jnp.array([0.0, 0.0, box_height * 2]), velocity_representation=VelRepr.Inertial, ) # =========================================== # Run the simulation and test the final state # =========================================== data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0) assert_allclose(data_tf.base_position[0:2], data_t0.base_position[0:2]) assert_allclose(data_tf.base_position[2] + max_penetration, box_height / 2) def test_simulation_with_rigid_contacts( jaxsim_model_box: js.model.JaxSimModel, integrator ): model = jaxsim_model_box with model.editable(validate=False) as model: # In order to achieve almost no penetration, we need to use a fairly large # Baumgarte stabilization term. model.contact_model = jaxsim.rbda.contacts.RigidContacts.build( solver_options={"solver_tol": 1e-3} ) model.contact_params = model.contact_model._parameters_class(K=1e5) # Enable a subset of the collidable points. enabled_collidable_points_mask = np.zeros( len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool ) enabled_collidable_points_mask[[0, 1, 2, 3]] = True model.kin_dyn_parameters.contact_parameters.enabled = tuple( enabled_collidable_points_mask.tolist() ) assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 # Initialize the maximum penetration of each collidable point at steady state. # This model is rigid, so we expect (almost) no penetration. max_penetration = 0.000 # Check jaxsim_model_box@conftest.py. box_height = 0.1 # Build the data of the model. 
data_t0 = js.data.JaxSimModelData.build( model=model, base_position=jnp.array([0.0, 0.0, box_height * 2]), velocity_representation=VelRepr.Inertial, ) # =========================================== # Run the simulation and test the final state # =========================================== data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0) assert_allclose(data_tf.base_position[0:2], data_t0.base_position[0:2]) assert_allclose(data_tf.base_position[2] + max_penetration, box_height / 2) def test_simulation_with_relaxed_rigid_contacts( jaxsim_model_box: js.model.JaxSimModel, integrator ): model = jaxsim_model_box with model.editable(validate=False) as model: model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts.build( solver_options={"tol": 1e-3}, ) model.contact_params = model.contact_model._parameters_class() # Enable a subset of the collidable points. enabled_collidable_points_mask = np.zeros( len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool ) enabled_collidable_points_mask[[0, 1, 2, 3]] = True model.kin_dyn_parameters.contact_parameters.enabled = tuple( enabled_collidable_points_mask.tolist() ) model.integrator = integrator assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 # Initialize the maximum penetration of each collidable point at steady state. # This model is quasi-rigid, so we expect (almost) no penetration. max_penetration = 0.000 # Check jaxsim_model_box@conftest.py. box_height = 0.1 # Build the data of the model. data_t0 = js.data.JaxSimModelData.build( model=model, base_position=jnp.array([0.0, 0.0, box_height * 2]), velocity_representation=VelRepr.Inertial, ) # =========================================== # Run the simulation and test the final state # =========================================== data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0) # With this contact model, we need to slightly increase the tolerances. 
assert_allclose(data_tf.base_position[0:2], data_t0.base_position[0:2], atol=1e-5) assert_allclose( data_tf.base_position[2] + max_penetration, box_height / 2, atol=1e-4 ) def test_joint_limits( jaxsim_model_single_pendulum: js.model.JaxSimModel, ): model = jaxsim_model_single_pendulum with model.editable(validate=False) as model: model.kin_dyn_parameters.joint_parameters.position_limits_max = jnp.atleast_1d( jnp.array(1.5708) ) model.kin_dyn_parameters.joint_parameters.position_limits_min = jnp.atleast_1d( jnp.array(-1.5708) ) model.kin_dyn_parameters.joint_parameters.position_limit_spring = ( jnp.atleast_1d(jnp.array(75.0)) ) model.kin_dyn_parameters.joint_parameters.position_limit_damper = ( jnp.atleast_1d(jnp.array(0.1)) ) position_limits_min, position_limits_max = js.joint.position_limits(model=model) data = js.data.JaxSimModelData.build( model=model, velocity_representation=VelRepr.Inertial, ) theta = 10 * np.pi / 180 # Define a tolerance since the spring-damper model does # not guarantee that the joint position will be exactly # below the limit. tolerance = theta * 0.10 # Test minimum joint position limits. data_t0 = data.replace(model=model, joint_positions=position_limits_min - theta) model = model.replace(time_step=0.005, validate=False) data_tf = run_simulation(model=model, data_t0=data_t0, tf=3.0) assert ( np.min(np.array(data_tf.joint_positions), axis=0) + tolerance >= position_limits_min ) # Test maximum joint position limits. 
data_t0 = data.replace(model=model, joint_positions=position_limits_max - theta) model = model.replace(time_step=0.001) data_tf = run_simulation(model=model, data_t0=data_t0, tf=3.0) assert ( np.max(np.array(data_tf.joint_positions), axis=0) - tolerance <= position_limits_max ) @pytest.mark.parametrize( "initial_joint_positions", [ jnp.array([0, 0]), np.pi / 180 * jnp.array([5, 0]), ], ) def test_simulation_with_kinematic_constraints_double_pendulum( jaxsim_model_double_pendulum: js.model.JaxSimModel, initial_joint_positions: jtp.Array, ): # ======== # Arrange # ======== tf = 1.0 # Final simulation time in seconds. model = jaxsim_model_double_pendulum frame_1_name = "right_link_extremity_frame" frame_2_name = "left_link_extremity_frame" frame_1_idx = js.frame.name_to_idx(model=model, frame_name=frame_1_name) frame_2_idx = js.frame.name_to_idx(model=model, frame_name=frame_2_name) # Define the kinematic constraints. constraints = js.kin_dyn_parameters.ConstraintMap() constraints = constraints.add_constraint( model=model, frame_idx_1=frame_1_idx, frame_idx_2=frame_2_idx, constraint_type=ConstraintType.Weld, ) # Set the constraints in the model. with model.editable(validate=False) as model: model.kin_dyn_parameters.constraints = constraints model.gravity = 0.0 # Build the initial data for the model. data_t0 = js.data.JaxSimModelData.build( model=model, velocity_representation=VelRepr.Inertial, joint_positions=initial_joint_positions, ) # ==== # Act # ==== # Simulate the model for a given time and time step. 
data_tf = run_simulation(model=model, data_t0=data_t0, tf=tf) # ========= # Assert # ========= # Assert that the chosen frames exist in the model assert frame_1_name in model.frame_names() assert frame_2_name in model.frame_names() # Assert that the joint positions are now equal actual_delta_s_tf = jnp.abs(data_tf.joint_positions[0] - data_tf.joint_positions[1]) expected_delta_s_tf = 0.0 assert_allclose( expected_delta_s_tf, actual_delta_s_tf, atol=1e-2, err_msg=f"Position difference [deg]: {actual_delta_s_tf * 180 / np.pi}", ) def test_simulation_with_kinematic_constraints_cartpole( jaxsim_model_cartpole: js.model.JaxSimModel, ): # ======== # Arrange # ======== tf = 1.0 # Final simulation time in seconds. model = jaxsim_model_cartpole frame_1_name = "cart_frame" frame_2_name = "rail_frame" frame_1_idx = js.frame.name_to_idx(model=model, frame_name=frame_1_name) frame_2_idx = js.frame.name_to_idx(model=model, frame_name=frame_2_name) # Define the kinematic constraints. constraints = js.kin_dyn_parameters.ConstraintMap() constraints = constraints.add_constraint( model, frame_1_idx, frame_2_idx, ConstraintType.Weld, ) # Set the initial joint positions with the cart displaced from the rail zero position. initial_joint_positions = jnp.array([0.05, 0.0]) # Set the constraints in the model. with model.editable(validate=False) as model: model.kin_dyn_parameters.constraints = constraints # Build the initial data for the model. data_t0 = js.data.JaxSimModelData.build( model=model, velocity_representation=VelRepr.Inertial, joint_positions=initial_joint_positions, ) # ==== # Act # ==== # Simulate the model for a given time and time step. 
data_tf = run_simulation(model=model, data_t0=data_t0, tf=tf) H_frame1 = js.frame.transform( model=model, data=data_tf, frame_index=frame_1_idx, ) H_frame2 = js.frame.transform( model=model, data=data_tf, frame_index=frame_2_idx, ) # ========= # Assert # ========= # Assert that the chosen frames exist in the model assert frame_1_name in model.frame_names() assert frame_2_name in model.frame_names() # Assert that the two frames are in the same pose actual_frame_error = jnp.linalg.inv(H_frame1) @ H_frame2 expected_frame_error = jnp.eye(4) assert_allclose(actual_frame_error, expected_frame_error, atol=1e-3) def test_simulation_with_kinematic_constraints_4_bar_linkage( jaxsim_model_4_bar_linkage: js.model.JaxSimModel, ): """Test kinematic weld constraint on 4-bar linkage model.""" # ======== # Arrange # ======== tf = 1.0 # Final simulation time in seconds. model = jaxsim_model_4_bar_linkage frame_1_name = "BC1_frame" frame_2_name = "BC2_frame" frame_1_idx = js.frame.name_to_idx(model=model, frame_name=frame_1_name) frame_2_idx = js.frame.name_to_idx(model=model, frame_name=frame_2_name) # Define the kinematic constraints. constraints = js.kin_dyn_parameters.ConstraintMap() constraints = constraints.add_constraint( model=model, frame_idx_1=frame_1_idx, frame_idx_2=frame_2_idx, constraint_type=ConstraintType.Weld, K_P=1e4, ) # Set the constraints in the model. with model.editable(validate=False) as model: model.kin_dyn_parameters.constraints = constraints # Build the initial data for the model (default base pose is fine). 
data_t0 = js.data.JaxSimModelData.build( model=model, velocity_representation=VelRepr.Inertial, base_position=jnp.array([0.0, 0.0, 0.10]), ) # ==== # Act # ==== data_tf = run_simulation(model=model, data_t0=data_t0, tf=tf) H_frame1 = js.frame.transform( model=model, data=data_tf, frame_index=frame_1_idx, ) H_frame2 = js.frame.transform( model=model, data=data_tf, frame_index=frame_2_idx, ) # ========= # Assert # ========= assert frame_1_name in model.frame_names() assert frame_2_name in model.frame_names() # Position check pos1 = H_frame1[:3, 3] pos2 = H_frame2[:3, 3] assert_allclose(pos1, pos2, atol=1e-5) # Orientation check R1 = H_frame1[:3, :3] R2 = H_frame2[:3, :3] R_err = R1.T @ R2 assert_allclose(R_err, jnp.eye(3), atol=1e-3) ================================================ FILE: tests/test_visualizer.py ================================================ import pytest import rod from jaxsim.mujoco import ModelToMjcf from jaxsim.mujoco.loaders import MujocoCamera @pytest.fixture def mujoco_camera(): return MujocoCamera.build_from_target_view( camera_name="test_camera", lookat=(0, 0, 0), distance=1, azimuth=0, elevation=0, fovy=45, degrees=True, ) def test_urdf_loading(jaxsim_model_single_pendulum, mujoco_camera): model = jaxsim_model_single_pendulum.built_from _ = ModelToMjcf.convert(model=model, cameras=mujoco_camera) def test_sdf_loading(jaxsim_model_single_pendulum, mujoco_camera): model = rod.Sdf.load(sdf=jaxsim_model_single_pendulum.built_from).serialize( pretty=True ) _ = ModelToMjcf.convert(model=model, cameras=mujoco_camera) def test_rod_loading(jaxsim_model_single_pendulum, mujoco_camera): model = rod.Sdf.load(sdf=jaxsim_model_single_pendulum.built_from).models()[0] _ = ModelToMjcf.convert(model=model, cameras=mujoco_camera) def test_heightmap(jaxsim_model_single_pendulum, mujoco_camera): model = rod.Sdf.load(sdf=jaxsim_model_single_pendulum.built_from).models()[0] _ = ModelToMjcf.convert( model=model, cameras=mujoco_camera, heightmap=True, 
heightmap_samples_xy=(51, 51), ) def test_inclined_plane(jaxsim_model_single_pendulum, mujoco_camera): model = rod.Sdf.load(sdf=jaxsim_model_single_pendulum.built_from).models()[0] _ = ModelToMjcf.convert( model=model, cameras=mujoco_camera, plane_normal=(0.3, 0.3, 0.3), ) ================================================ FILE: tests/utils.py ================================================ from __future__ import annotations import dataclasses import pathlib import idyntree.bindings as idt import numpy as np import numpy.typing as npt import jaxsim.api as js from jaxsim import VelRepr def assert_allclose(actual, desired, rtol=1e-7, atol=1e-9, err_msg=""): """ Assert allclose with custom default tolerances. Normalizes only signed zeros using np.copysign. """ actual = np.asarray(actual, dtype=float) desired = np.asarray(desired, dtype=float) # Normalize zeros to avoid -0.0 vs 0.0 mismatches. actual = actual + 0.0 desired = desired + 0.0 np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, err_msg=err_msg) def build_kindyncomputations_from_jaxsim_model( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, considered_joints: list[str] | None = None, removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None, ) -> KinDynComputations: """ Build a `KinDynComputations` from `JaxSimModel` and `JaxSimModelData`. Args: model: The `JaxSimModel` from which to build the `KinDynComputations`. data: The `JaxSimModelData` from which to build the `KinDynComputations`. considered_joints: The list of joint names to consider in the `KinDynComputations`. removed_joint_positions: A dictionary defining the positions of the removed joints (default is 0). Returns: The `KinDynComputations` built from the `JaxSimModel` and `JaxSimModelData`. Note: Only `JaxSimModel` built from URDF files are supported. 
""" if ( isinstance(model.built_from, pathlib.Path) and model.built_from.suffix != ".urdf" ) or (isinstance(model.built_from, str) and " KinDynComputations: """ Store the state of a `JaxSimModelData` in `KinDynComputations`. Args: data: The `JaxSimModelData` providing the desired state to copy. kin_dyn: The `KinDynComputations` in which to store the state of `JaxSimModelData`. Returns: The updated `KinDynComputations` with the state of `JaxSimModelData`. """ if kin_dyn.dofs() != data.joint_positions.size: raise ValueError(data) with data.switch_velocity_representation(kin_dyn.vel_repr): kin_dyn.set_robot_state( joint_positions=np.array(data.joint_positions), joint_velocities=np.array(data.joint_velocities), base_transform=np.array(data._base_transform), base_velocity=np.array(data.base_velocity), ) return kin_dyn @dataclasses.dataclass class KinDynComputations: """High-level wrapper of the iDynTree KinDynComputations class.""" vel_repr: VelRepr gravity: npt.NDArray kin_dyn: idt.KinDynComputations @staticmethod def build( urdf: pathlib.Path | str, considered_joints: list[str] | None = None, vel_repr: VelRepr = VelRepr.Inertial, gravity: npt.NDArray = np.array([0, 0, -10.0]), removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None, ) -> KinDynComputations: # Read the URDF description. urdf_string = urdf.read_text() if isinstance(urdf, pathlib.Path) else urdf # Create the model loader. mdl_loader = idt.ModelLoader() # Handle removed_joint_positions if None. removed_joint_positions = ( {name: float(pos) for name, pos in removed_joint_positions.items()} if removed_joint_positions is not None else {} ) # Load the URDF description. if not ( mdl_loader.loadModelFromString(urdf_string) if considered_joints is None else mdl_loader.loadReducedModelFromString( urdf_string, considered_joints, removed_joint_positions ) ): raise RuntimeError("Failed to load URDF description") # Create KinDynComputations and insert the model. 
kindyn = idt.KinDynComputations() if not kindyn.loadRobotModel(mdl_loader.model()): raise RuntimeError("Failed to load model") vel_repr_to_idyntree = { VelRepr.Inertial: idt.INERTIAL_FIXED_REPRESENTATION, VelRepr.Body: idt.BODY_FIXED_REPRESENTATION, VelRepr.Mixed: idt.MIXED_REPRESENTATION, } # Configure the frame representation. if not kindyn.setFrameVelocityRepresentation(vel_repr_to_idyntree[vel_repr]): raise RuntimeError("Failed to set the frame representation") return KinDynComputations( kin_dyn=kindyn, vel_repr=vel_repr, gravity=np.array(gravity).squeeze(), ) def set_robot_state( self, joint_positions: npt.NDArray | None = None, joint_velocities: npt.NDArray | None = None, base_transform: npt.NDArray = np.eye(4), base_velocity: npt.NDArray = np.zeros(6), world_gravity: npt.NDArray | None = None, ) -> None: joint_positions = ( joint_positions if joint_positions is not None else np.zeros(self.dofs()) ) joint_velocities = ( joint_velocities if joint_velocities is not None else np.zeros(self.dofs()) ) gravity = world_gravity if world_gravity is not None else self.gravity if joint_positions.size != self.dofs(): raise ValueError(joint_positions.size, self.dofs()) if joint_velocities.size != self.dofs(): raise ValueError(joint_velocities.size, self.dofs()) if gravity.size != 3: raise ValueError(gravity.size, 3) if base_transform.shape != (4, 4): raise ValueError(base_transform.shape, (4, 4)) if base_velocity.size != 6: raise ValueError(base_velocity.size) g = idt.Vector3().FromPython(np.array(gravity)) s = idt.VectorDynSize().FromPython(np.array(joint_positions)) s_dot = idt.VectorDynSize().FromPython(np.array(joint_velocities)) p = idt.Position(*[float(i) for i in np.array(base_transform[0:3, 3])]) R = idt.Rotation() R = R.FromPython(np.array(base_transform[0:3, 0:3])) world_H_base = idt.Transform() world_H_base.setPosition(p) world_H_base.setRotation(R) v_WB = idt.Twist().FromPython(base_velocity) if not self.kin_dyn.setRobotState(world_H_base, s, v_WB, s_dot, g): 
raise RuntimeError("Failed to set the robot state") # Update stored gravity. self.gravity = gravity def dofs(self) -> int: return self.kin_dyn.getNrOfDegreesOfFreedom() def joint_names(self) -> list[str]: model: idt.Model = self.kin_dyn.model() return [model.getJointName(i) for i in range(model.getNrOfJoints())] def link_names(self) -> list[str]: return [ self.kin_dyn.getFrameName(i) for i in range(self.kin_dyn.getNrOfLinks()) ] def frame_names(self) -> list[str]: return [ self.kin_dyn.getFrameName(i) for i in range(self.kin_dyn.getNrOfLinks(), self.kin_dyn.getNrOfFrames()) ] def joint_positions(self) -> npt.NDArray: vector = idt.VectorDynSize() if not self.kin_dyn.getJointPos(vector): raise RuntimeError("Failed to extract joint positions") return vector.toNumPy() def joint_velocities(self) -> npt.NDArray: vector = idt.VectorDynSize() if not self.kin_dyn.getJointVel(vector): raise RuntimeError("Failed to extract joint velocities") return vector.toNumPy() def jacobian_frame(self, frame_name: str) -> npt.NDArray: if self.kin_dyn.getFrameIndex(frame_name) < 0: raise ValueError(f"Frame '{frame_name}' does not exist") J = idt.MatrixDynSize(6, 6 + self.dofs()) if not self.kin_dyn.getFrameFreeFloatingJacobian(frame_name, J): raise RuntimeError("Failed to get the frame free-floating jacobian") return J.toNumPy() def total_mass(self) -> float: model: idt.Model = self.kin_dyn.model() return model.getTotalMass() def link_spatial_inertia(self, link_name: str) -> npt.NDArray: if link_name not in self.link_names(): raise ValueError(link_name) model = self.kin_dyn.model() link: idt.Link = model.getLink(model.getLinkIndex(link_name)) return link.inertia().asMatrix().toNumPy() def link_mass(self, link_name: str) -> float: if link_name not in self.link_names(): raise ValueError(link_name) model = self.kin_dyn.model() link: idt.Link = model.getLink(model.getLinkIndex(link_name)) return link.getInertia().asVector().toNumPy()[0] def floating_base_frame(self) -> str: return 
self.kin_dyn.getFloatingBase() def frame_transform(self, frame_name: str) -> npt.NDArray: if self.kin_dyn.getFrameIndex(frame_name) < 0: raise ValueError(f"Frame '{frame_name}' does not exist") if frame_name == self.floating_base_frame(): H_idt = self.kin_dyn.getWorldBaseTransform() else: H_idt = self.kin_dyn.getWorldTransform(frame_name) H = np.eye(4) H[0:3, 3] = H_idt.getPosition().toNumPy() H[0:3, 0:3] = H_idt.getRotation().toNumPy() return H def frame_relative_transform( self, ref_frame_name: str, frame_name: str ) -> npt.NDArray: if self.kin_dyn.getFrameIndex(ref_frame_name) < 0: raise ValueError(f"Frame '{ref_frame_name}' does not exist") if self.kin_dyn.getFrameIndex(frame_name) < 0: raise ValueError(f"Frame '{frame_name}' does not exist") ref_H_frame: idt.Transform = self.kin_dyn.getRelativeTransform( ref_frame_name, frame_name ) H = np.eye(4) H[0:3, 3] = ref_H_frame.getPosition().toNumPy() H[0:3, 0:3] = ref_H_frame.getRotation().toNumPy() return H def frame_parent_link_name(self, frame_name: str) -> str: return self.kin_dyn.model().getLinkName( self.kin_dyn.model().getFrameLink( self.kin_dyn.model().getFrameIndex(frame_name) ) ) def base_velocity(self) -> npt.NDArray: nu = idt.VectorDynSize() if not self.kin_dyn.getModelVel(nu): raise RuntimeError("Failed to get the model velocity") return nu.toNumPy()[0:6] def frame_velocity(self, frame_name: str) -> npt.NDArray: if self.kin_dyn.getFrameIndex(frame_name) < 0: raise ValueError(f"Frame '{frame_name}' does not exist") v_WF = self.kin_dyn.getFrameVel(frame_name) return v_WF.toNumPy() def frame_bias_acc(self, frame_name: str) -> npt.NDArray: if self.kin_dyn.getFrameIndex(frame_name) < 0: raise ValueError(f"Frame '{frame_name}' does not exist") J̇ν = self.kin_dyn.getFrameBiasAcc(frame_name) return J̇ν.toNumPy() def com_position(self) -> npt.NDArray: W_p_G = self.kin_dyn.getCenterOfMassPosition() return W_p_G.toNumPy() def com_velocity(self) -> npt.NDArray: W_ṗ_G = self.kin_dyn.getCenterOfMassVelocity() return 
W_ṗ_G.toNumPy() def com_bias_acceleration(self) -> npt.NDArray: return self.kin_dyn.getCenterOfMassBiasAcc().toNumPy() def mass_matrix(self) -> npt.NDArray: M = idt.MatrixDynSize() if not self.kin_dyn.getFreeFloatingMassMatrix(M): raise RuntimeError("Failed to get the free floating mass matrix") return M.toNumPy() def bias_forces(self) -> npt.NDArray: h = idt.FreeFloatingGeneralizedTorques(self.kin_dyn.model()) if not self.kin_dyn.generalizedBiasForces(h): raise RuntimeError("Failed to get the generalized bias forces") base_wrench: idt.Wrench = h.baseWrench() joint_torques: idt.JointDOFsDoubleArray = h.jointTorques() return np.hstack( [base_wrench.toNumPy().flatten(), joint_torques.toNumPy().flatten()] ) def gravity_forces(self) -> npt.NDArray: g = idt.FreeFloatingGeneralizedTorques(self.kin_dyn.model()) if not self.kin_dyn.generalizedGravityForces(g): raise RuntimeError("Failed to get the generalized gravity forces") base_wrench: idt.Wrench = g.baseWrench() joint_torques: idt.JointDOFsDoubleArray = g.jointTorques() return np.hstack( [base_wrench.toNumPy().flatten(), joint_torques.toNumPy().flatten()] ) def total_momentum(self) -> npt.NDArray: return self.kin_dyn.getLinearAngularMomentum().toNumPy().flatten() def centroidal_momentum(self) -> npt.NDArray: return self.kin_dyn.getCentroidalTotalMomentum().toNumPy().flatten() def total_momentum_jacobian(self) -> npt.NDArray: Jh = idt.MatrixDynSize() if not self.kin_dyn.getLinearAngularMomentumJacobian(Jh): raise RuntimeError("Failed to get the total momentum jacobian") return Jh.toNumPy() def centroidal_momentum_jacobian(self) -> npt.NDArray: Jh = idt.MatrixDynSize() if not self.kin_dyn.getCentroidalTotalMomentumJacobian(Jh): raise RuntimeError("Failed to get the centroidal momentum jacobian") return Jh.toNumPy() def locked_spatial_inertia(self) -> npt.NDArray: return self.kin_dyn.getRobotLockedInertia().asMatrix().toNumPy() def locked_centroidal_spatial_inertia(self) -> npt.NDArray: return 
self.kin_dyn.getCentroidalRobotLockedInertia().asMatrix().toNumPy() def average_velocity(self) -> npt.NDArray: return self.kin_dyn.getAverageVelocity().toNumPy() def average_velocity_jacobian(self) -> npt.NDArray: Jh = idt.MatrixDynSize() if not self.kin_dyn.getAverageVelocityJacobian(Jh): raise RuntimeError("Failed to get the average velocity jacobian") return Jh.toNumPy() def average_centroidal_velocity(self) -> npt.NDArray: return self.kin_dyn.getCentroidalAverageVelocity().toNumPy() def average_centroidal_velocity_jacobian(self) -> npt.NDArray: Jh = idt.MatrixDynSize() if not self.kin_dyn.getCentroidalAverageVelocityJacobian(Jh): raise RuntimeError("Failed to get the average centroidal velocity jacobian") return Jh.toNumPy()