[
  {
    "path": ".devcontainer/Dockerfile",
    "content": "# syntax=docker/dockerfile:1.4\nFROM mcr.microsoft.com/devcontainers/base:jammy\n\nARG PROJECT_NAME=jaxsim\nARG PIXI_VERSION=v0.35.0\n\nRUN curl -o /usr/local/bin/pixi -SL https://github.com/prefix-dev/pixi/releases/download/${PIXI_VERSION}/pixi-$(uname -m)-unknown-linux-musl \\\n    && chmod +x /usr/local/bin/pixi \\\n    && pixi info\n\n# Add LFS repository and install.\nRUN apt-get update && apt-get install -y curl \\\n    && curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash \\\n    && apt install -y git-lfs\n\nUSER vscode\nWORKDIR /home/vscode\n\nRUN echo 'eval \"$(pixi completion -s bash)\"' >> /home/vscode/.bashrc\n"
  },
  {
    "path": ".devcontainer/devcontainer.json",
    "content": "// For format details, see https://aka.ms/devcontainer.json. For config options, see the\n// README at: https://github.com/devcontainers/templates/tree/main/src/ubuntu\n{\n\t\"name\": \"Ubuntu\",\n\t// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile\n\t\"build\": {\n\t\t\"context\": \"..\",\n\t\t\"dockerfile\": \"Dockerfile\"\n\t},\n\n\t// Features to add to the dev container. More info: https://containers.dev/features.\n\t\"features\": {\n        \"ghcr.io/devcontainers/features/docker-in-docker:2\": {}\n    },\n\n    // Put `.pixi` folder in a mounted volume of a case-insensitive filesystem.\n    \"mounts\": [\"source=${localWorkspaceFolderBasename}-pixi,target=${containerWorkspaceFolder}/.pixi,type=volume\"],\n\n\t// Use 'forwardPorts' to make a list of ports inside the container available locally.\n\t// \"forwardPorts\": [],\n\n\t// Use 'postCreateCommand' to run commands after the container is created.\n\t\"postCreateCommand\": \"sudo chown vscode .pixi && git lfs pull --include='pixi.lock' && pixi install --environment=test-cpu\",\n\n\t// Configure tool-specific properties.\n\t// \"customizations\": {},\n\n\t// Uncomment to connect as root instead. 
More info: https://aka.ms/dev-containers-non-root.\n\t// \"remoteUser\": \"root\"\n\n\t// VSCode extensions\n\t\"customizations\": {\n\t\t\"vscode\": {\n\t\t\t\"settings\": {\n                \t\t\"python.pythonPath\": \"/workspaces/jaxsim/.pixi/envs/test-cpu/bin/python\",\n\t\t\t\t\"python.defaultInterpreterPath\": \"/workspaces/jaxsim/.pixi/envs/test-cpu/bin/python\",\n\t\t\t\t\"python.terminal.activateEnvironment\": true,\n\t\t\t\t\"python.terminal.activateEnvInCurrentTerminal\": true\n\t\t\t},\n\t\t\t\"extensions\": [\n\t\t\t\t\"ms-python.python\",\n\t\t\t\t\"donjayamanne.python-extension-pack\",\n\t\t\t\t\"ms-toolsai.jupyter\",\n\t\t\t\t\"GitHub.codespaces\",\n                \t\t\"GitHub.copilot\",\n\t\t\t\t\"ms-azuretools.vscode-docker\",\n                \t\t\"charliermarsh.ruff\"\n\t\t\t]\n\t\t}\n\t}\n}\n"
  },
  {
    "path": ".gitattributes",
    "content": "# Store the pixi lockfile in Git LFS\npixi.lock filter=lfs diff=lfs merge=lfs -text\n"
  },
  {
    "path": ".github/CODEOWNERS",
    "content": "*       @flferretti\n"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "version: 2\nupdates:\n\n  # Check for updates to GitHub Actions every month.\n  - package-ecosystem: github-actions\n    directory: /\n    schedule:\n      interval: monthly\n    # Disable automatic rebasing of existing pull requests.\n    rebase-strategy: \"disabled\"\n    # Group all updates into a single PR.\n    groups:\n      dependencies:\n        patterns:\n          - '*'\n"
  },
  {
    "path": ".github/release.yml",
    "content": "changelog:\n  exclude:\n    authors:\n      - dependabot[bot]\n      - pre-commit-ci[bot]\n      - github-actions[bot]\n"
  },
  {
    "path": ".github/workflows/ci_cd.yml",
    "content": "name: Python CI/CD\n\non:\n  workflow_dispatch:\n  push:\n  pull_request:\n  release:\n    types:\n      - published\n  schedule:\n  # Execute a nightly build at 2am UTC.\n  - cron:  '0 2 * * *'\n\n\njobs:\n\n  package:\n    name: Package the project\n    runs-on: ubuntu-latest\n\n    steps:\n\n      - uses: actions/checkout@v6\n        with:\n          fetch-depth: 0\n\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.11\"\n\n      - name: Install Python tools\n        run: pip install build twine\n\n      - name: Create distributions\n        run: python -m build -o dist/\n\n      - name: Inspect dist folder\n        run: ls -lah dist/\n\n      - name: Check wheel's abi and platform tags\n        run: test $(find dist/ -name *-none-any.whl | wc -l) -gt 0\n\n      - name: Run twine check\n        run: twine check dist/*\n\n      - name: Upload artifacts\n        uses: actions/upload-artifact@v7\n        with:\n          path: dist/*\n          name: dist\n\n  test:\n    name: 'Python${{ matrix.python }}@${{ matrix.os }}'\n    needs: package\n    runs-on: ${{ matrix.os }}\n    env:\n      PYTHONUTF8: \"1\"\n    strategy:\n      fail-fast: false\n      matrix:\n        os:\n          - ubuntu-latest\n          - macos-latest\n          - windows-latest\n        python:\n          - \"3.10\"\n          - \"3.11\"\n          - \"3.12\"\n          - \"3.13\"\n\n    steps:\n\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: ${{ matrix.python }}\n\n      - name: Download Python packages\n        uses: actions/download-artifact@v8\n        with:\n          path: dist\n          name: dist\n\n      - name: Install wheel (ubuntu)\n        if: contains(matrix.os, 'ubuntu')\n        shell: bash\n        run: pip install \"$(find dist/ -type f -name '*.whl')\"\n\n      - name: Install wheel (macos|windows)\n        if: contains(matrix.os, 
'macos') || contains(matrix.os, 'windows')\n        shell: bash\n        run: pip install \"$(find dist/ -type f -name '*.whl')\"\n\n      - name: Document installed pip packages\n        shell: bash\n        run: pip list --verbose\n\n      - name: Import the package\n        run: python -c \"import jaxsim\"\n\n      - uses: actions/checkout@v6\n        with:\n          lfs: true\n\n      - uses: prefix-dev/setup-pixi@v0.9.5\n        if: contains(matrix.os, 'ubuntu')\n        with:\n          pixi-version: \"latest\"\n          frozen: true\n          cache: true\n          cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}\n\n      - name: Run the Python tests\n        if: |\n          contains(matrix.os, 'ubuntu') &&\n          (github.event_name != 'pull_request')\n        run: pixi run --frozen test --numprocesses auto\n        env:\n          # https://github.com/pytest-dev/pytest/issues/7443#issuecomment-656642591\n          PY_COLORS: \"1\"\n          JAX_PLATFORM_NAME: cpu\n\n  publish:\n    name: Publish to PyPI\n    needs: test\n    runs-on: ubuntu-latest\n    permissions:\n        id-token: write\n\n    steps:\n\n      - name: Download Python packages\n        uses: actions/download-artifact@v8\n        with:\n          path: dist\n          name: dist\n\n      - name: Inspect dist folder\n        run: ls -lah dist/\n\n      - name: Publish to PyPI\n        if: |\n          github.repository == 'gbionics/jaxsim' &&\n          ((github.event_name == 'push' && github.ref == 'refs/heads/main') ||\n           (github.event_name == 'release'))\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          skip-existing: true\n"
  },
  {
    "path": ".github/workflows/gpu_benchmark.yml",
    "content": "name: GPU Benchmarks\n\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    types: [opened, reopened, synchronize]\n  workflow_dispatch:\n  schedule:\n    - cron: \"0 0  * * 1\" # Run At 00:00 on Monday\n\npermissions:\n  pull-requests: write\n  deployments: write\n  contents: write\n\njobs:\n  benchmark:\n    runs-on: self-hosted\n    container:\n      image: ghcr.io/prefix-dev/pixi:0.46.0-noble@sha256:c12bcbe8ba5dfd71867495d3471b95a6993b79cc7de7eafec016f8f59e4e4961\n      options: --rm --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -e \"TERM=xterm-256color\"\n\n    steps:\n      - name: Install Git and Git-LFS\n        run: |\n            apt update && apt install -y git git-lfs\n\n      - name: Checkout repository\n        uses: actions/checkout@v6\n        with:\n          lfs: true\n          fetch-depth: 0\n\n      - name: Fetch pixi.lock from LFS\n        run: |\n          git config --global safe.directory /__w/jaxsim/jaxsim\n          git lfs checkout pixi.lock\n\n      - name: Get main branch SHA\n        id: get-main-branch-sha\n        run: |\n          SHA=$(git rev-parse origin/main)\n          echo \"sha=$SHA\" >> $GITHUB_OUTPUT\n\n      - name: Get benchmark results from main branch\n        id: cache\n        uses: actions/cache/restore@v5\n        with:\n          path: ./cache\n          key: ${{ runner.os }}-benchmark\n\n      - name: Run benchmark and store result\n        run: |\n            pixi run --frozen --environment gpu benchmark --batch-size 128 --benchmark-json output.json\n        env:\n            PY_COLORS: \"1\"\n\n      - name: Compare benchmark results with main branch\n        uses: benchmark-action/github-action-benchmark@v1.22.0\n        with:\n          tool: 'pytest'\n          output-file-path: output.json\n          external-data-json-path: ./cache/benchmark-data.json\n          save-data-file: false\n          fail-on-alert: true\n          summary-always: true\n          
comment-always: true\n          alert-threshold: 150%\n          github-token: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Store benchmark result for main branch\n        uses: benchmark-action/github-action-benchmark@v1.22.0\n        if: ${{ github.ref_name == 'main' }}\n        with:\n          tool: 'pytest'\n          output-file-path: output.json\n          external-data-json-path: ./cache/benchmark-data.json\n          save-data-file: true\n          fail-on-alert: false\n          summary-always: true\n          comment-always: true\n          alert-threshold: 150%\n          github-token: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Publish Benchmark Results to GitHub Pages\n        uses: benchmark-action/github-action-benchmark@v1.22.0\n        if: ${{ github.ref_name == 'main' }}\n        with:\n          tool: 'pytest'\n          output-file-path: output.json\n          benchmark-data-dir-path: \"benchmarks\"\n          fail-on-alert: false\n          github-token: ${{ secrets.GITHUB_TOKEN }}\n          comment-on-alert: true\n          summary-always: true\n          save-data-file: true\n          alert-threshold: \"150%\"\n          auto-push: ${{ github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' }}\n\n      - name: Update Benchmark Results cache\n        uses: actions/cache/save@v5\n        if: ${{ github.ref_name == 'main' }}\n        with:\n          path: ./cache\n          key: ${{ runner.os }}-benchmark\n"
  },
  {
    "path": ".github/workflows/pixi.yml",
    "content": "name: Pixi\n\npermissions:\n  contents: write\n  pull-requests: write\n\non:\n  workflow_dispatch:\n  schedule:\n    # Execute at 5am UTC on the first day of the month.\n    - cron: '0 5 1 * *'\n\njobs:\n\n  pixi-update:\n    runs-on: ubuntu-24.04\n\n    steps:\n\n      - uses: actions/checkout@v6\n        with:\n          lfs: true\n\n      - name: Set up pixi\n        uses: prefix-dev/setup-pixi@v0.9.5\n        with:\n          run-install: false\n\n      - name: Install pixi-diff-to-markdown\n        run: pixi global install pixi-diff-to-markdown\n\n      - name: Update pixi lockfile and generate diff\n        run: |\n          set -o pipefail\n          pixi update --json | pixi exec pixi-diff-to-markdown --explicit-column > diff.md\n\n      - name: Test project against updated pixi\n        run: pixi run --environment default test\n        env:\n          PY_COLORS: \"1\"\n          JAX_PLATFORM_NAME: cpu\n\n      - name: Set pull request branch name\n        run: echo \"BRANCH_NAME=update-pixi-$(date +'%Y%m%d%H%M%S')\" >> $GITHUB_ENV\n\n      - name: Create pull request\n        uses: peter-evans/create-pull-request@v8\n        with:\n          token: ${{ secrets.GITHUB_TOKEN }}\n          commit-message: Update `pixi.lock`\n          title: Update `pixi` lockfile\n          body-path: diff.md\n          branch: ${{ env.BRANCH_NAME }}\n          base: main\n          labels: pixi\n          add-paths: pixi.lock\n          delete-branch: true\n          committer: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>\n          author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>\n"
  },
  {
    "path": ".github/workflows/read_the_docs.yml",
    "content": "name: Read the Docs PR\non:\n  pull_request_target:\n    types:\n      - opened\n\npermissions:\n  pull-requests: write\n\njobs:\n\n  documentation-links:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: readthedocs/actions/preview@v1\n        with:\n          project-slug: \"jaxsim\"\n          project-language: \"\"\n"
  },
  {
    "path": ".gitignore",
    "content": "# IDEs\n.idea*\n.vscode/\n\n# Matlab\n*.m~\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\ndocs/_collections/\ndocs/modules/_autosummary/\ndocs/modules/generated\ndocs/sg_execution_times.rst\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# dynamic version\nsrc/jaxsim/_version.py\n\n# ruff\n.ruff_cache/\n\n# pixi environments\n.pixi\n\n# data\n*.mp4\n*.png\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "ci:\n  autofix_prs: false\n  autoupdate_schedule: quarterly\n  submodules: false\n\ndefault_language_version:\n  python: python3\n\nrepos:\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v6.0.0\n    hooks:\n    - id: check-ast\n    - id: check-merge-conflict\n    - id: check-yaml\n    - id: end-of-file-fixer\n    - id: trailing-whitespace\n    - id: check-toml\n    - id: check-added-large-files\n      args: [\"--maxkb=2000\"]\n\n  - repo: https://github.com/psf/black-pre-commit-mirror\n    rev: 26.3.1\n    hooks:\n      - id: black\n        args: [\"--check\", \"--diff\"]\n\n  - repo: https://github.com/pycqa/isort\n    rev: 8.0.1\n    hooks:\n      - id: isort\n        args: [\"--check\", \"--diff\"]\n\n  - repo: https://github.com/pre-commit/pygrep-hooks\n    rev: v1.10.0\n    hooks:\n      - id: rst-backticks\n      - id: rst-directive-colons\n      - id: rst-inline-touching-normal\n\n  - repo: https://github.com/codespell-project/codespell\n    rev: v2.4.2\n    hooks:\n      - id: codespell\n        args: [\"-S\", \"*.lock\"]\n\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.15.9\n    hooks:\n      - id: ruff\n\n  - repo: https://github.com/kynan/nbstripout\n    rev: 0.9.1\n    hooks:\n      - id: nbstripout\n"
  },
  {
    "path": ".readthedocs.yaml",
    "content": "version: \"2\"\n\nbuild:\n  os: ubuntu-24.04\n  tools:\n    python: \"mambaforge-23.11\"\n\nconda:\n  environment: environment.yml\n\npython:\n  install:\n    - method: pip\n      path: .\n\nsphinx:\n  configuration: docs/conf.py\n\nformats: all\n"
  },
  {
    "path": "CITATION.cff",
    "content": "cff-version: 1.2.0\ntitle: JaxSim\nmessage: \"If you use this software, please cite the paper.\"\ntype: software\nauthors:\n  - family-names: Ferretti\n    given-names: Filippo Luca\n    affiliation: \"Generative Bionics\"\n  - family-names: Ferigo\n    given-names: Diego\n    affiliation: \"Robotics & AI Institute\"\n  - family-names: Croci\n    given-names: Alessandro\n    affiliation: \"NEURA Robotics\"\n  - family-names: Sartore\n    given-names: Carlotta\n    affiliation: \"Generative Bionics\"\n  - family-names: Younis\n    given-names: Omar G.\n    affiliation: \"Quebec AI Institute\"\n  - family-names: Traversaro\n    given-names: Silvio\n    affiliation: \"Generative Bionics\"\n  - family-names: Pucci\n    given-names: Daniele\n    affiliation: \"Generative Bionics\"\nrepository-code: \"https://github.com/gbionics/jaxsim\"\nlicense: BSD-3-Clause\npreferred-citation:\n  type: article\n  title: \"Contact-Aware Morphology Optimization via Physically Consistent Differentiable Simulation\"\n  authors:\n    - family-names: Ferretti\n      given-names: Filippo Luca\n    - family-names: Ferigo\n      given-names: Diego\n    - family-names: Croci\n      given-names: Alessandro\n    - family-names: Sartore\n      given-names: Carlotta\n    - family-names: Younis\n      given-names: Omar G.\n    - family-names: Traversaro\n      given-names: Silvio\n    - family-names: Pucci\n      given-names: Daniele\n  journal: \"IEEE Robotics and Automation Letters\"\n  year: 2026\n  doi: \"10.1109/LRA.2026.3678125\"\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to JAXsim :rocket:\n\nHello Contributor,\n\nWe're thrilled that you're considering contributing to JAXsim!\nHere's a brief guide to help you seamlessly become a part of our project.\n\n## Development Environment :hammer_and_wrench:\n\nMake sure your development environment is set up.\nFollow the installation instructions in the [README](./README.md) to get JAXsim and its dependencies up and running.\n\nTo ensure consistency and maintain code quality, we recommend using the pre-commit hook with the following configuration.\nThis will help catch issues before they become a part of the project.\n\n### Setting Up Pre-commit Hook :fishing_pole_and_fish:\n\n`pre-commit` is a tool that manages pre-commit hooks for your project.\nIt will run checks on your code before you commit it, ensuring that it meets the project's standards.\n\nFirst, install `pre-commit` if you haven't already:\n\n```bash\npip install pre-commit\n```\n\nThen, run the following command to install the hooks:\n\n```bash\npre-commit install\n```\n\n### Using Pre-commit Hook :vertical_traffic_light:\n\nBefore making any commits, the pre-commit hook will automatically run.\nIf it finds any issues, it will prevent the commit and provide instructions on how to fix them.\n\nTo get your commit through without fixing the issues, use the `--no-verify` flag:\n\n```bash\ngit commit -m \"Your commit message\" --no-verify\n```\n\nTo manually run the pre-commit hook at any time, use:\n\n```bash\npre-commit run --all-files\n```\n\n## Making Changes :construction:\n\nBefore submitting a pull request, create an issue to discuss your changes if major changes are involved.\nThis helps us understand your needs and provide feedback.\nClearly describe your pull request, referencing any related issues.\nFollow the [PEP 8](https://peps.python.org/pep-0008/) style guide and include relevant tests.\n\n## Testing :test_tube:\n\nYour code will be tested with the CI/CD pipeline before merging.\nFeel 
free to add new tests or update the existing ones to cover your changes; the CI/CD workflows are defined in the [workflows](./.github/workflows) folder.\n\n## Documentation :book:\n\nUpdate the documentation in the [docs](./docs) folder and the [README](./README.md) to reflect your changes, if necessary.\nThere is no need to build the documentation locally; it will be automatically built and deployed with your pull request, where a preview link will be provided.\n\n## Code Review :eyes:\n\nExpect feedback during the code review process.\nAddress comments and make necessary changes.\nThis collaboration ensures quality.\nPlease keep the commit history clean, or squash commits if necessary.\n\n## License :scroll:\n\nJAXsim is licensed under the [BSD 3-Clause License](./LICENSE).\nBy contributing, you agree that your contributions will be licensed under the same license.\n\nThank you for contributing to JAXsim! Your efforts are appreciated.\n"
  },
  {
    "path": "LICENSE",
    "content": "BSD 3-Clause License\n\nCopyright (c) 2022, Artificial and Mechanical Intelligence\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright notice, this\n   list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright notice,\n   this list of conditions and the following disclaimer in the documentation\n   and/or other materials provided with the distribution.\n\n3. Neither the name of the copyright holder nor the names of its\n   contributors may be used to endorse or promote products derived from\n   this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "README.md",
    "content": "# JaxSim\n\n**JaxSim** is a **differentiable physics engine** built with JAX, tailored for co-design and robotic learning applications.\n\n<div align=\"center\">\n<br/>\n<table>\n  <tr>\n    <th><img src=\"https://github.com/user-attachments/assets/89d0b4ca-7e0c-4f58-bf3e-9540e35b9a01\" style=\"height:300px; width:400px; object-fit:cover;\"></th>\n    <th><img src=\"https://github.com/user-attachments/assets/a909e388-d7b4-4b58-89f1-035da8636d94\" style=\"height:300px; width:400px; object-fit:cover;\"></th>\n  </tr>\n  <tr>\n    <th><img src=\"https://github.com/user-attachments/assets/3692bc06-18ed-406d-80bd-480780346224\" style=\"height:300px; width:400px; object-fit:cover;\"></th>\n    <th><img src=\"https://github.com/user-attachments/assets/3356f332-4710-4946-9a82-a8c2305dab88\" style=\"height:300px; width:400px; object-fit:cover;\"></th>\n  </tr>\n</table>\n<br/>\n</div>\n\n## Features\n\n- Physically consistent differentiability w.r.t. hardware parameters.\n- Closed chain dynamics support.\n- Reduced-coordinate physics engine for **fixed-base** and **floating-base** robots.\n- Fully Python-based, leveraging [JAX][jax] following a functional programming paradigm.\n- Seamless execution on CPUs, GPUs, and TPUs.\n- Supports JIT compilation and automatic vectorization for high performance.\n- Compatible with SDF models and URDF (via [sdformat][sdformat] conversion).\n\n> [!WARNING]\n> This project is still experimental. 
APIs may change between releases without notice.\n\n> [!NOTE]\n> JaxSim currently focuses on locomotion applications.\n> Only contacts between bodies and smooth ground surfaces are supported.\n\n## How to use it\n\n```python\nimport pathlib\n\nimport icub_models\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\n\n# Load the iCub model\nmodel_path = icub_models.get_model_file(\"iCubGazeboV2_5\")\n\njoints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',\n          'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',\n          'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',\n          'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',\n          'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',\n          'r_ankle_roll')\n\n# Build and reduce the model\nmodel_description = pathlib.Path(model_path)\n\nfull_model = js.model.JaxSimModel.build_from_model_description(\n    model_description=model_description,\n)\n\nmodel = js.model.reduce(model=full_model, considered_joints=joints)\n\n# Get the number of degrees of freedom\nndof = model.dofs()\n\n# Initialize data and simulation\n# Note that the default data representation is mixed velocity representation\ndata = js.data.JaxSimModelData.build(\n    model=model, base_position=jnp.array([0.0, 0.0, 1.0])\n)\n\nT = jnp.arange(start=0, stop=1.0, step=model.time_step)\n\ntau = jnp.zeros(ndof)\n\n# Simulate\nfor _ in T:\n    data = js.model.step(\n        model=model, data=data, link_forces=None, joint_force_references=tau\n    )\n```\n\nCheck the example folder for additional use cases!\n\n[jax]: https://github.com/google/jax/\n[sdformat]: https://github.com/gazebosim/sdformat\n[notation]: https://research.tue.nl/en/publications/multibody-dynamics-notation-version-2\n[passive_viewer_mujoco]: https://mujoco.readthedocs.io/en/stable/python.html#passive-viewer\n\n## Installation\n\n<details>\n<summary>With <code>conda</code></summary>\n\nYou 
can install the project using [`conda`][conda] as follows:\n\n```bash\nconda install jaxsim -c conda-forge\n```\n\nGPU support for JAX will be automatically installed if a compatible GPU is detected.\n\n</details>\n\n<details>\n<summary>With <code>pixi</code></summary>\n\n> ### Note\n> The minimum version of `pixi` required is `0.39.0`.\n\nSince the `pixi.lock` file is stored using Git LFS, make sure you have [Git LFS](https://github.com/git-lfs/git-lfs/blob/main/INSTALLING.md) installed and properly configured on your system before installation. After cloning the repository, run:\n\n```bash\ngit lfs install && git lfs pull\n```\n\nThis ensures all LFS-tracked files are properly downloaded before you proceed with the installation.\n\nYou can add the `jaxsim` dependency in your [`pixi`][pixi] project as follows:\n\n```bash\npixi add jaxsim\n```\n\nIf you are on Linux and you want to use a `cuda`-powered version of `jax`, remember to add the appropriate line in the [`system-requirements`](https://pixi.sh/latest/reference/pixi_manifest/#the-system-requirements-table) table, i.e. 
adding\n\n~~~toml\n[system-requirements]\ncuda = \"13\"\n~~~\n\nif you are using a `pixi.toml` file or\n\n~~~toml\n[tool.pixi.system-requirements]\ncuda = \"13\"\n~~~\n\nif you are using a `pyproject.toml` file.\n\n</details>\n\n<details>\n<summary>With <code>pip</code></summary>\n\nYou can install the project using [`pypa/pip`][pip], preferably in a [virtual environment][venv], as follows:\n\n```bash\npip install jaxsim\n```\n\nCheck [`pyproject.toml`](pyproject.toml) for the complete list of optional dependencies.\nYou can obtain a full installation using `jaxsim[all]`.\n\nIf you need URDF support, follow the [official instructions](https://gazebosim.org/docs) to install Gazebo Sim on your operating system,\nmaking sure to obtain `sdformat ≥ 13.0` and `gz-tools ≥ 2.0`.\n\nYou don't need to install the entire Gazebo Sim suite.\nFor example, on Ubuntu, it is sufficient to install the `libsdformat*` and `gz-tools2` packages.\n\nIf you need GPU support, follow the official [installation instructions][jax_gpu] of JAX.\n\n</details>\n\n<details>\n<summary>Contributors installation (with <code>conda</code>)</summary>\n\nIf you want to contribute to the project, we recommend creating the following `jaxsim` conda environment first:\n\n```bash\nconda env create -f environment.yml\n```\n\nThen, activate the environment and install the project in editable mode:\n\n```bash\nconda activate jaxsim\npip install --no-deps -e .\n```\n\n</details>\n\n<details>\n<summary>Contributors installation (with <code>pixi</code>)</summary>\n\n> ### Note\n> The minimum version of `pixi` required is `0.39.0`.\n\nSince the `pixi.lock` file is stored using Git LFS, make sure you have [Git LFS](https://github.com/git-lfs/git-lfs/blob/main/INSTALLING.md) installed and properly configured on your system before installation. 
After cloning the repository, run:\n\n```bash\ngit lfs install && git lfs pull\n```\n\nThis ensures all LFS-tracked files are properly downloaded before you proceed with the installation.\n\nYou can install the default dependencies of the project using [`pixi`][pixi] as follows:\n\n```bash\npixi install\n```\n\nSee `pixi task list` for a list of available tasks.\n\n</details>\n\n[conda]: https://anaconda.org/\n[pip]: https://github.com/pypa/pip/\n[pixi]: https://pixi.sh/\n[venv]: https://docs.python.org/3/tutorial/venv.html\n[jax_gpu]: https://github.com/google/jax/#installation\n\n## Documentation\n\nThe JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs].\n\n[readthedocs]: https://jaxsim.readthedocs.io/\n\n## Additional features\n\nJaxsim can also be used as a multi-body dynamics library! With full support for automatic differentiation of RBDAs (forwards and reverse mode) and automatic differentiation against both kinematic and dynamic parameters.\n\n### Using JaxSim as a multibody dynamics library\n\n```python\nimport pathlib\n\nimport icub_models\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\n\n# Load the iCub model\nmodel_path = icub_models.get_model_file(\"iCubGazeboV2_5\")\n\njoints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',\n          'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',\n          'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',\n          'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',\n          'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch',\n          'r_ankle_roll')\n\n# Build and reduce the model\nmodel_description = pathlib.Path(model_path)\n\nfull_model = js.model.JaxSimModel.build_from_model_description(\n    model_description=model_description,\n)\n\nmodel = js.model.reduce(model=full_model, considered_joints=joints)\n\n# Initialize model data\ndata = js.data.JaxSimModelData.build(\n    model=model,\n    
base_position=jnp.array([0.0, 0.0, 1.0]),\n)\n\n# Frame and dynamics computations\nframe_index = js.frame.name_to_idx(model=model, frame_name=\"l_foot\")\n\n# Frame transformation\nW_H_F = js.frame.transform(\n    model=model, data=data, frame_index=frame_index\n)\n\n# Frame Jacobian\nW_J_F = js.frame.jacobian(\n    model=model, data=data, frame_index=frame_index\n)\n\n# Dynamics properties\nM = js.model.free_floating_mass_matrix(model=model, data=data)      # Mass matrix\nh = js.model.free_floating_bias_forces(model=model, data=data)      # Bias forces\ng = js.model.free_floating_gravity_forces(model=model, data=data)   # Gravity forces\nC = js.model.free_floating_coriolis_matrix(model=model, data=data)  # Coriolis matrix\n\n# Print dynamics results\nprint(f\"{M.shape=} \\n{h.shape=} \\n{g.shape=} \\n{C.shape=}\")\n```\n\n## Credits\n\nThe RBDAs are based on the theory of the [Rigid Body Dynamics Algorithms][RBDA]\nbook by Roy Featherstone.\nThe algorithms and some simulation features were inspired by its accompanying [code][spatial_v2].\n\n[RBDA]: https://link.springer.com/book/10.1007/978-1-4899-7560-7\n[spatial_v2]: http://royfeatherstone.org/spatial/index.html#spatial-software\n\nThe development of JaxSim started in late 2021, inspired by early versions of [`google/brax`][brax].\nAt that time, Brax was implemented in maximal coordinates, and we wanted a physics engine in reduced coordinates.\nWe are grateful to the Brax team for their work and for showing the potential of [JAX][jax] in this field.\n\nBrax v2 was later implemented with reduced coordinates, following an approach comparable to JaxSim.\nThe development then shifted to [MJX][mjx], which provides a JAX-based implementation of the Mujoco APIs.\n\nThe main differences between MJX/Brax and JaxSim are as follows:\n\n- JaxSim supports out-of-the-box all SDF models with [Pose Frame Semantics][PFS].\n- JaxSim only supports collisions between points rigidly attached to bodies and a compliant ground 
surface.\n\n[brax]: https://github.com/google/brax\n[mjx]: https://mujoco.readthedocs.io/en/3.0.0/mjx.html\n[PFS]: http://sdformat.org/tutorials?tut=pose_frame_semantics\n\n## Contributing\n\nWe welcome contributions from the community.\nPlease read the [contributing guide](./CONTRIBUTING.md) to get started.\n\n## Citing\n\nIf you use JaxSim in your work, please cite the following paper:\n\n```bibtex\n@article{ferretti_contact_aware_2026,\n  author       = {Filippo Luca Ferretti and Diego Ferigo and Alessandro Croci and Carlotta Sartore and Omar G. Younis and Silvio Traversaro and Daniele Pucci},\n  title        = {Contact-Aware Morphology Optimization via Physically Consistent Differentiable Simulation},\n  journal      = {IEEE Robotics and Automation Letters},\n  year         = {2026},\n  doi          = {10.1109/LRA.2026.3678125}\n}\n```\n\n## People\n\n| Authors | Maintainer |\n|:------:|:-----------:|\n| [<img src=\"https://avatars.githubusercontent.com/u/469199?v=4\" width=\"40\">][df] [<img src=\"https://avatars.githubusercontent.com/u/102977828?v=4\" width=\"40\">][ff] | [<img src=\"https://avatars.githubusercontent.com/u/102977828?v=4\" width=\"40\">][ff] |\n\n[df]: https://github.com/diegoferigo\n[ff]: https://github.com/flferretti\n\n## License\n\n[BSD3](https://choosealicense.com/licenses/bsd-3-clause/)\n"
  },
  {
    "path": "docs/Makefile",
    "content": "SPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = .\nBUILDDIR      = _build\nSPHINXPROJ    = JAXsim\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\nimport os\nimport sys\n\nif os.environ.get(\"READTHEDOCS\"):\n    checkout_name = os.path.basename(os.path.dirname(os.path.realpath(__file__)))\n    os.environ[\"CONDA_PREFIX\"] = os.path.realpath(\n        os.path.join(\"..\", \"..\", \"conda\", checkout_name)\n    )\n\nimport jaxsim\n\n# -- Version information\n\nsys.path.insert(0, os.path.abspath(\".\"))\nsys.path.insert(0, os.path.abspath(\"../\"))\nsys.path.insert(0, os.path.abspath(\"../../\"))\n\nmodule_path = os.path.abspath(\"../src/\")\nsys.path.insert(0, module_path)\n\n__version__ = jaxsim._version.__version__\n\n# -- Project information\n\nproject = \"JAXsim\"\ncopyright = \"2022, Artificial and Mechanical Intelligence\"\nauthor = \"Artificial and Mechanical Intelligence\"\n\nrelease = version = __version__\n\n# -- General configuration\n\nextensions = [\n    \"sphinx.ext.duration\",\n    \"sphinx.ext.doctest\",\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.autosummary\",\n    \"sphinx.ext.intersphinx\",\n    \"sphinx.ext.mathjax\",\n    \"sphinx.ext.ifconfig\",\n    \"sphinx.ext.viewcode\",\n    \"sphinx_rtd_theme\",\n    \"sphinx.ext.napoleon\",\n    \"sphinx_autodoc_typehints\",\n    \"sphinx_multiversion\",\n    \"myst_nb\",\n    \"sphinx_gallery.gen_gallery\",\n    \"sphinxcontrib.collections\",\n    \"sphinx_design\",\n]\n\n# -- Options for intersphinx extension\n\nlanguage = \"en\"\n\nhtml_theme = \"sphinx_book_theme\"\n\ntemplates_path = [\"_templates\"]\n\nhtml_title = f\"JAXsim {version}\"\n\nmaster_doc = \"index\"\n\nautodoc_typehints_format = \"short\"\n\nautodoc_typehints = \"description\"\n\nautosummary_generate = True\n\nepub_show_urls = \"footnote\"\n\n# Enable postponed evaluation of annotations (PEP 563)\nautodoc_type_aliases = {\n    \"jaxsim.typing.PyTree\": \"jaxsim.typing.PyTree\",\n    \"jaxsim.typing.Vector\": \"jaxsim.typing.Vector\",\n    \"jaxsim.typing.Matrix\": \"jaxsim.typing.Matrix\",\n    
\"jaxsim.typing.Array\": \"jaxsim.typing.Array\",\n    \"jaxsim.typing.Int\": \"jaxsim.typing.Int\",\n    \"jaxsim.typing.Bool\": \"jaxsim.typing.Bool\",\n    \"jaxsim.typing.Float\": \"jaxsim.typing.Float\",\n    \"jaxsim.typing.ScalarLike\": \"jaxsim.typing.ScalarLike\",\n    \"jaxsim.typing.ArrayLike\": \"jaxsim.typing.ArrayLike\",\n    \"jaxsim.typing.VectorLike\": \"jaxsim.typing.VectorLike\",\n    \"jaxsim.typing.MatrixLike\": \"jaxsim.typing.MatrixLike\",\n    \"jaxsim.typing.IntLike\": \"jaxsim.typing.IntLike\",\n    \"jaxsim.typing.BoolLike\": \"jaxsim.typing.BoolLike\",\n    \"jaxsim.typing.FloatLike\": \"jaxsim.typing.FloatLike\",\n}\n\n# -- Options for sphinx-collections\n\ncollections = {\n    \"examples\": {\"driver\": \"copy_folder\", \"source\": \"../examples/\", \"ignore\": \"assets\"}\n}\n\n# -- Options for sphinx-gallery ----------------------------------------------\n\nsphinx_gallery_conf = {\n    \"examples_dirs\": \"../examples\",\n    \"gallery_dirs\": \"../generated_examples/\",\n    \"doc_module\": \"jaxsim\",\n}\n\n# -- Options for myst -------------------------------------------------------\nmyst_enable_extensions = [\n    \"amsmath\",\n    \"dollarmath\",\n]\nnb_execution_mode = \"auto\"\nnb_execution_raise_on_error = True\nnb_render_image_options = {\n    \"scale\": \"60\",\n}\nnb_execution_timeout = 180\n\nsource_suffix = [\".rst\", \".md\", \".ipynb\"]\n\n# Ignore header warnings\nsuppress_warnings = [\"myst.header\"]\n"
  },
  {
    "path": "docs/examples.rst",
    "content": ".. _collections:\n\nExample Notebooks\n=================\n\n.. toctree::\n    :glob:\n    :hidden:\n    :maxdepth: 1\n\n    _collections/examples/README.md\n\n.. raw:: html\n\n    <div class=\"sphx-glr-thumbnails\">\n\n    <div class=\"sphx-glr-thumbcontainer\" tooltip=\"JaxSim as a hardware-accelerated parallel physics engine\">\n\n.. only:: html\n\n    :doc:`_collections/examples/jaxsim_as_physics_engine`\n\n.. raw:: html\n\n        <div class=\"sphx-glr-thumbnail-title\">JaxSim as a hardware-accelerated parallel physics engine</div>\n    </div>\n\n    <div class=\"sphx-glr-thumbcontainer\" tooltip=\"JaxSim as a hardware-accelerated parallel physics engine [Advanced]\">\n\n.. only:: html\n\n    :doc:`_collections/examples/jaxsim_as_physics_engine_advanced`\n\n.. raw:: html\n\n        <div class=\"sphx-glr-thumbnail-title\">JaxSim as a hardware-accelerated parallel physics engine [Advanced]</div>\n    </div>\n\n    <div class=\"sphx-glr-thumbcontainer\" tooltip=\"JaxSim as a multibody dynamics library\">\n\n.. only:: html\n\n    :doc:`_collections/examples/jaxsim_as_multibody_dynamics_library`\n\n.. raw:: html\n\n        <div class=\"sphx-glr-thumbnail-title\">JaxSim as a multibody dynamics library</div>\n    </div>\n\n    <div class=\"sphx-glr-thumbcontainer\" tooltip=\"JaxSim for developing closed-loop robot controllers\">\n\n.. only:: html\n\n    :doc:`_collections/examples/jaxsim_for_robot_controllers`\n\n.. raw:: html\n\n        <div class=\"sphx-glr-thumbnail-title\">JaxSim for developing closed-loop robot controllers</div>\n    </div>\n\n    </div>\n"
  },
  {
    "path": "docs/guide/configuration.rst",
    "content": "Configuration\n=============\n\nJaxSim utilizes environment variables for application configuration. Below is a detailed overview of the various configuration categories and their respective variables.\n\n\nCollision Dynamics\n~~~~~~~~~~~~~~~~~~\n\nEnvironment variables starting with ``JAXSIM_COLLISION_`` are used to configure collision dynamics. The available variables are:\n\n- ``JAXSIM_COLLISION_SPHERE_POINTS``: Specifies the number of collision points to approximate the sphere.\n\n  *Default:* ``50``.\n\n- ``JAXSIM_COLLISION_MESH_ENABLED``: Enables or disables mesh-based collision detection.\n\n  *Default:* ``False``.\n\n- ``JAXSIM_COLLISION_USE_BOTTOM_ONLY``: Limits collision detection to only the bottom half of the box or sphere.\n\n  *Default:* ``False``.\n\n.. note::\n  The bottom half is defined as the half of the box or sphere with the lowest z-coordinate in the collision link frame.\n\n\nTesting\n~~~~~~~\n\nFor testing configurations, environment variables beginning with ``JAXSIM_TEST_`` are used. The following variables are available:\n\n- ``JAXSIM_TEST_SEED``: Defines the seed for the random number generator.\n\n  *Default:* ``0``.\n\n- ``JAXSIM_TEST_AD_ORDER``: Specifies the gradient order for automatic differentiation tests.\n\n  *Default:* ``1``.\n\n- ``JAXSIM_TEST_FD_STEP_SIZE``: Sets the step size for finite difference tests.\n\n  *Default:* the cube root of the machine epsilon.\n\n\nJoint Dynamics\n~~~~~~~~~~~~~~\nJoint dynamics are configured using environment variables starting with ``JAXSIM_JOINT_``. 
Available variables include:\n\n- ``JAXSIM_JOINT_POSITION_LIMIT_DAMPER``: Overrides the damper value for joint position limits of the SDF model.\n\n- ``JAXSIM_JOINT_POSITION_LIMIT_SPRING``: Overrides the spring value for joint position limits of the SDF model.\n\n\nLogging and Exceptions\n~~~~~~~~~~~~~~~~~~~~~~\n\nLogging and runtime exceptions are configured by the following environment variables:\n\n- ``JAXSIM_LOGGING_LEVEL``: Determines the logging level.\n\n  *Default:* ``DEBUG`` for development, ``WARNING`` for production.\n\n- ``JAXSIM_ENABLE_EXCEPTIONS``: Enables runtime checks and exceptions. Note that enabling exceptions might lead to device-to-host data transfers, increasing the computational time required.\n\n  *Default:* ``False``.\n\n.. note::\n    Runtime exceptions are disabled by default on TPU.\n"
  },
  {
    "path": "docs/guide/install.rst",
    "content": "Installation\n============\n\n.. _installation:\n\nPrerequisites\n-------------\n\nJAXsim requires Python 3.11 or later.\n\nBasic Installation\n------------------\n\nYou can install the project using `conda`_:\n\n.. code-block:: bash\n\n   conda install jaxsim -c conda-forge\n\nAlternatively, you can use `pypa/pip`_, preferably in a `virtual environment`_:\n\n.. code-block:: bash\n\n   pip install jaxsim\n\nHave a look at `pyproject.toml`_ for a complete list of optional dependencies.\nYou can install all of them with ``pip install \"jaxsim[all]\"``.\n\n.. note::\n\n    If you need GPU support, please follow the official JAX `installation instructions`_.\n\n.. _conda: https://anaconda.org/\n.. _pyproject.toml: https://github.com/gbionics/jaxsim/blob/main/pyproject.toml\n.. _pypa/pip: https://github.com/pypa/pip/\n.. _virtual environment: https://docs.python.org/3/tutorial/venv.html\n.. _installation instructions: https://github.com/google/jax/#installation\n"
  },
  {
    "path": "docs/index.rst",
    "content": "JAXsim\n#######\n\nA scalable physics engine and multibody dynamics library implemented with JAX. With JIT batteries 🔋\n\n.. note::\n    This simulator currently focuses on locomotion applications. Only contacts with ground are supported.\n\nFeatures\n--------\n\n.. grid::\n\n   .. grid-item::\n      :columns: 12 12 12 6\n\n      .. card:: Performance\n         :class-card: sd-border-0\n         :shadow: none\n         :class-title: sd-fs-5\n\n         .. div:: sd-font-normal\n\n            Physics engine in reduced coordinates implemented with JAX_.\n            Compatibility with JIT compilation for increased performance and transparent support to execute logic on CPUs, GPUs, and TPUs.\n            Parallel multi-body simulations on hardware accelerators for significantly increased throughput\n\n   .. grid-item::\n      :columns: 12 12 12 6\n\n      .. card:: Model Parsing\n         :class-card: sd-border-0\n         :shadow: none\n         :class-title: sd-fs-5\n\n         .. div:: sd-font-normal\n\n            Support for SDF models (and, upon conversion, URDF models). Revolute, prismatic, and fixed joints supported.\n\n   .. grid-item::\n      :columns: 12 12 12 6\n\n      .. card:: Automatic Differentiation\n         :class-card: sd-border-0\n         :shadow: none\n         :class-title: sd-fs-5\n\n         .. div:: sd-font-normal\n\n            Support for automatic differentiation of rigid body dynamics algorithms (RBDAs) for model-based robotics research.\n            Soft contacts model supporting full friction cone and sticking / slipping transition.\n\n   .. grid-item::\n      :columns: 12 12 12 6\n\n      .. card:: Complex Dynamics\n         :class-card: sd-border-0\n         :shadow: none\n         :class-title: sd-fs-5\n\n         .. 
div:: sd-font-normal\n\n            JAXsim provides a variety of integrators for the simulation of multibody dynamics, including RK4, Heun, Euler, and more.\n            Support of `multiple velocities representations <https://research.tue.nl/en/publications/multibody-dynamics-notation-version-2>`_.\n\n\n----\n\n.. toctree::\n  :hidden:\n\n  guide/install\n  guide/configuration\n\n  examples\n\n.. toctree::\n  :hidden:\n  :maxdepth: 2\n  :caption: JAXsim API\n\n  modules/api\n  modules/math\n  modules/mujoco\n  modules/parsers\n  modules/rbda\n  modules/typing\n  modules/utils\n\nExamples\n--------\n\nExplore and learn how to use the library through practical demonstrations available in the `examples <https://github.com/gbionics/jaxsim/tree/main/examples>`__ folder.\n\nCredits\n-------\n\nThe physics module of JAXsim is based on the theory of the `Rigid Body Dynamics Algorithms <https://link.springer.com/book/10.1007/978-1-4899-7560-7>`_ book by Roy Featherstone.\nWe structured part of our logic following its accompanying `code <http://royfeatherstone.org/spatial/index.html#spatial-software>`_.\nThe physics engine is developed entirely in Python using JAX_.\n\nThe inspiration for developing JAXsim originally stemmed from early versions of Brax_.\nHere below we summarize the differences between the projects:\n\n- JAXsim simulates multibody dynamics in reduced coordinates, while :code:`brax v1` uses maximal coordinates.\n- The new v2 APIs of brax (and the new MJX_) were then implemented in reduced coordinates, following an approach comparable to JAXsim, with major differences in contact handling.\n- The rigid-body algorithms used in JAXsim allow to efficiently compute quantities based on the Euler-Poincarè\n  formulation of the equations of motion, necessary for model-based robotics research.\n- JAXsim supports SDF (and, indirectly, URDF) models, assuming the model is described with the\n  recent `Pose Frame Semantics 
<http://sdformat.org/tutorials?tut=pose_frame_semantics>`_.\n- Unlike brax, JAXsim only supports collision detection between bodies and a compliant ground surface.\n- The RBDAs of JAXsim support automatic differentiation, but this functionality has not been thoroughly tested.\n\n\nPeople\n------\n\nAuthors\n'''''''\n\n- `Diego Ferigo <https://github.com/diegoferigo>`_\n- `Filippo Luca Ferretti <https://github.com/flferretti>`_\n\nMaintainers\n'''''''''''\n\n- `Filippo Luca Ferretti <https://github.com/flferretti>`_\n- `Alessandro Croci <https://github.com/xela-95>`_\n\nLicense\n-------\n\n`BSD3 <https://choosealicense.com/licenses/bsd-3-clause/>`_\n\n.. _Brax: https://github.com/google/brax\n.. _MJX: https://mujoco.readthedocs.io/en/3.0.0/mjx.html\n.. _JAX: https://github.com/google/jax\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sphinx-build\r\n)\r\nset SOURCEDIR=.\r\nset BUILDDIR=_build\r\nset SPHINXPROJ=JaxSim\r\n\r\nif \"%1\" == \"\" goto help\r\n\r\n%SPHINXBUILD% >NUL 2>NUL\r\nif errorlevel 9009 (\r\n\techo.\r\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\r\n\techo.installed, then set the SPHINXBUILD environment variable to point\r\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\r\n\techo.may add the Sphinx directory to PATH.\r\n\techo.\r\n\techo.If you don't have Sphinx installed, grab it from\r\n\techo.http://sphinx-doc.org/\r\n\texit /b 1\r\n)\r\n\r\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\ngoto end\r\n\r\n:help\r\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\n\r\n:end\r\npopd\r\n"
  },
  {
    "path": "docs/modules/api.rst",
    "content": "Functional API\n==============\n\n.. currentmodule:: jaxsim.api\n\n.. autosummary::\n   :toctree: _autosummary\n\n    model\n    data\n    contact\n    kin_dyn_parameters\n    integrators\n    joint\n    link\n    frame\n    com\n    ode\n    references\n    actuation_model\n    common\n\n\nModel\n~~~~~\n\n.. automodule:: jaxsim.api.model\n    :members:\n    :no-index:\n\n.. automodule:: jaxsim.api.actuation_model\n    :members:\n    :no-index:\n\nData\n~~~~\n\n.. automodule:: jaxsim.api.data\n    :members:\n    :no-index:\n\nContact\n~~~~~~~\n\n.. automodule:: jaxsim.api.contact\n    :members:\n    :no-index:\n\nKinDynParameters\n~~~~~~~~~~~~~~~~\n\n.. automodule:: jaxsim.api.kin_dyn_parameters\n    :members:\n    :no-index:\n\nJoint\n~~~~~\n\n.. automodule:: jaxsim.api.joint\n    :members:\n    :no-index:\n\nLink\n~~~~~\n\n.. automodule:: jaxsim.api.link\n    :members:\n    :no-index:\n\nFrame\n~~~~~\n\n.. automodule:: jaxsim.api.frame\n    :members:\n    :no-index:\n\nCoM\n~~~\n\n.. automodule:: jaxsim.api.com\n    :members:\n    :no-index:\n\nIntegration\n~~~~~~~~~~~\n\n.. automodule:: jaxsim.api.integrators\n    :members:\n    :no-index:\n\n\n.. automodule:: jaxsim.api.ode\n    :members:\n    :no-index:\n\nReferences\n~~~~~~~~~~\n\n.. automodule:: jaxsim.api.references\n    :members:\n    :no-index:\n\nCommon\n~~~~~~\n\n.. autoclass:: jaxsim.api.common.VelRepr\n    :members:\n\n.. autoclass:: jaxsim.api.common.ModelDataWithVelocityRepresentation\n    :members:\n"
  },
  {
    "path": "docs/modules/math.rst",
    "content": "Math\n====\n\n.. currentmodule:: jaxsim.math\n\n.. automodule:: jaxsim.math.adjoint\n    :members:\n    :undoc-members:\n\n.. automodule:: jaxsim.math.cross\n    :members:\n    :undoc-members:\n\n.. automodule:: jaxsim.math.inertia\n    :members:\n    :undoc-members:\n\n.. automodule:: jaxsim.math.quaternion\n    :members:\n    :undoc-members:\n\n.. automodule:: jaxsim.math.rotation\n    :members:\n    :undoc-members:\n\n.. automodule:: jaxsim.math.skew\n    :members:\n    :undoc-members:\n"
  },
  {
    "path": "docs/modules/mujoco.rst",
    "content": "MuJoCo Visualizer\n==================\n\nJAXsim provides a simple interface with MuJoCo's visualizer. The visualizer is\na separate process that communicates with the main simulation process. This\nallows for the simulation to run at full speed while the visualizer can run at\na different frame rate.\n\n.. currentmodule:: jaxsim.mujoco\n\nLoaders\n~~~~~~~\n\n.. automodule:: jaxsim.mujoco.loaders\n    :members:\n\nModel\n~~~~~\n\n.. automodule:: jaxsim.mujoco.model\n    :members:\n\nVisualizer\n~~~~~~~~~~\n\n.. automodule:: jaxsim.mujoco.visualizer\n    :members:\n"
  },
  {
    "path": "docs/modules/parsers.rst",
    "content": "Parsers\n=======\n\n.. automodule:: jaxsim.parsers.descriptions.collision\n    :members:\n\n.. automodule:: jaxsim.parsers.descriptions.joint\n    :members:\n\n.. automodule:: jaxsim.parsers.descriptions.link\n    :members:\n\n.. automodule:: jaxsim.parsers.descriptions.model\n    :members:\n"
  },
  {
    "path": "docs/modules/rbda.rst",
    "content": "Rigid Body Dynamics Algorithms\n==============================\n\nThis module provides a set of algorithms for rigid body dynamics.\n\n.. currentmodule:: jaxsim.rbda\n\n.. autosummary::\n    :toctree: _autosummary\n\n    aba\n    collidable_points\n    contacts.soft\n    contacts.rigid\n    contacts.relaxed_rigid\n    crba\n    forward_kinematics\n    jacobian\n    utils\n\nCollision Detection\n~~~~~~~~~~~~~~~~~~~\n\n.. automodule:: jaxsim.rbda.collidable_points\n    :members:\n    :no-index:\n\nContact Models\n~~~~~~~~~~~~~~\n\n.. automodule:: jaxsim.rbda.contacts.soft\n    :members:\n    :no-index:\n\n.. automodule:: jaxsim.rbda.contacts.rigid\n    :members:\n    :no-index:\n\n.. automodule:: jaxsim.rbda.contacts.relaxed_rigid\n    :members:\n    :no-index:\n\nUtilities\n~~~~~~~~~\n\n.. automodule:: jaxsim.rbda.utils\n    :members:\n    :no-index:\n"
  },
  {
    "path": "docs/modules/typing.rst",
    "content": "Typing\n======\n\n.. currentmodule:: jaxsim.typing\n\n.. autosummary::\n    PyTree\n    Matrix\n    Bool\n    Int\n    Float\n    Vector\n    BoolLike\n    FloatLike\n    IntLike\n    ArrayLike\n    VectorLike\n    MatrixLike\n"
  },
  {
    "path": "docs/modules/utils.rst",
    "content": "Utils\n=====\n\n.. automodule:: jaxsim.utils\n    :members:\n    :inherited-members:\n\n.. autoclass:: jaxsim.utils.JaxsimDataclass\n    :members:\n    :inherited-members:\n"
  },
  {
    "path": "environment.yml",
    "content": "name: jaxsim\nchannels:\n  - conda-forge\ndependencies:\n  # ===========================\n  # Dependencies from setup.cfg\n  # ===========================\n  - python >= 3.12.0\n  - coloredlogs\n  - jax >= 0.4.34\n  - jaxlib >= 0.4.34\n  - jaxlie >= 1.3.0\n  - jax-dataclasses >= 1.4.0\n  - optax >= 0.2.3\n  - pptree\n  - qpax\n  - rod >= 0.3.3\n  - trimesh\n  - typing_extensions # python<3.12\n  # ====================================\n  # Optional dependencies from setup.cfg\n  # ====================================\n  # [testing]\n  - chex\n  - idyntree >= 12.2.1\n  - pytest\n  - pytest-benchmark\n  - pytest-icdiff\n  - robot_descriptions >= 1.16.0\n  - icub-models\n  # [viz]\n  - lxml\n  - mediapy\n  - mujoco >= 3.0.0\n  - scipy >= 1.14.0\n  # ==========================\n  # Documentation dependencies\n  # ==========================\n  - cachecontrol\n  - filecache\n  - jinja2\n  - myst-nb\n  - pip\n  - sphinx\n  - sphinx-autodoc-typehints\n  - sphinx-book-theme\n  - sphinx-copybutton\n  - sphinx-design\n  - sphinx_fontawesome\n  - sphinx-gallery\n  - sphinx-jinja2-compat\n  - sphinx-multiversion\n  - sphinx_rtd_theme\n  - sphinx-toolbox\n  - icub-models\n  # ========================================\n  # Other dependencies for GitHub Codespaces\n  # ========================================\n  - ipython\n  - pip:\n    - sphinx-collections # TODO (flferretti): PR to conda-forge\n"
  },
  {
    "path": "examples/.gitattributes",
    "content": "# GitHub syntax highlighting\npixi.lock linguist-language=YAML\n"
  },
  {
    "path": "examples/.gitignore",
    "content": "# pixi environments\n.pixi\n"
  },
  {
    "path": "examples/README.md",
    "content": "# JaxSim Examples\n\nThis folder contains Jupyter notebooks that demonstrate the practical usage of JaxSim.\n\n## Featured examples\n\n| Notebook | Google Colab | Description |\n| :--- | :---: | :--- |\n| [`jaxsim_as_multibody_dynamics_library`](./jaxsim_as_multibody_dynamics_library.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_multibody_dynamics] | An example demonstrating how to use JaxSim as a multibody dynamics library. |\n| [`jaxsim_as_physics_engine.ipynb`](./jaxsim_as_physics_engine.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_as_physics_engine] | An example demonstrating how to simulate vectorized models in parallel. |\n| [`jaxsim_as_physics_engine_advanced.ipynb`](./jaxsim_as_physics_engine_advanced.ipynb) | [![Open In Colab][colab_badge]][jaxsim_as_physics_engine_advanced] | An example showcasing advanced JaxSim usage, such as customizing the integrator, contact model, and more. |\n| [`jaxsim_for_robot_controllers.ipynb`](./jaxsim_for_robot_controllers.ipynb) | [![Open In Colab][colab_badge]][ipynb_jaxsim_closed_loop] | A basic example showing how to simulate a PD controller with gravity compensation for a 2-DOF cart-pole. 
|\n\n[colab_badge]: https://colab.research.google.com/assets/colab-badge.svg\n[ipynb_jaxsim_closed_loop]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_for_robot_controllers.ipynb\n[ipynb_jaxsim_as_physics_engine]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_physics_engine.ipynb\n[jaxsim_as_physics_engine_advanced]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_physics_engine_advanced.ipynb\n[ipynb_jaxsim_as_multibody_dynamics]: https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_multibody_dynamics_library.ipynb\n\n## How to run the examples\n\nYou can run the JaxSim examples with hardware acceleration in two ways.\n\n### Option 1: Google Colab (recommended)\n\nThe easiest way is to use the provided Google Colab links to run the notebooks in a hosted environment\nwith no setup required.\n\n### Option 2: Local execution with `pixi`\n\nTo run the examples locally, first install `pixi` following the [official documentation][pixi_installation]:\n\n[pixi_installation]: https://pixi.sh/#installation\n\n```bash\ncurl -fsSL https://pixi.sh/install.sh | bash\n```\n\nThen, from the repository's root directory, execute the example notebooks using:\n\n```bash\npixi run examples\n```\n\nThis command will automatically handle all necessary dependencies and run the examples in a self-contained environment.\n"
  },
  {
    "path": "examples/assets/build_cartpole_urdf.py",
    "content": "import os\n\nif \"ROD_LOGGING_LEVEL\" not in os.environ:\n    os.environ[\"ROD_LOGGING_LEVEL\"] = \"WARNING\"\n\nimport numpy as np\nimport rod.kinematics.tree_transforms\nfrom rod.builder import primitives\n\nif __name__ == \"__main__\":\n\n    # ================\n    # Model parameters\n    # ================\n\n    # Rail parameters.\n    rail_height = 1.2\n    rail_length = 5.0\n    rail_radius = 0.005\n    rail_mass = 5.0\n\n    # Cart parameters.\n    cart_mass = 1.0\n    cart_size = (0.1, 0.2, 0.05)\n\n    # Pole parameters.\n    pole_mass = 0.5\n    pole_length = 1.0\n    pole_radius = 0.005\n\n    # ========================\n    # Create the link builders\n    # ========================\n\n    rail_builder = primitives.CylinderBuilder(\n        name=\"rail\",\n        mass=rail_mass,\n        radius=rail_radius,\n        length=rail_length,\n    )\n\n    cart_builder = primitives.BoxBuilder(\n        name=\"cart\",\n        mass=cart_mass,\n        x=cart_size[0],\n        y=cart_size[1],\n        z=cart_size[2],\n    )\n\n    pole_builder = primitives.CylinderBuilder(\n        name=\"pole\",\n        mass=pole_mass,\n        radius=pole_radius,\n        length=pole_length,\n    )\n\n    # =================\n    # Create the joints\n    # =================\n\n    world_to_rail = rod.Joint(\n        name=\"world_to_rail\",\n        type=\"fixed\",\n        parent=\"world\",\n        child=rail_builder.name,\n        pose=primitives.PrimitiveBuilder.build_pose(\n            relative_to=\"world\",\n        ),\n    )\n\n    linear = rod.Joint(\n        name=\"linear\",\n        type=\"prismatic\",\n        parent=rail_builder.name,\n        child=cart_builder.name,\n        pose=primitives.PrimitiveBuilder.build_pose(\n            relative_to=rail_builder.name,\n            pos=np.array([0, 0, rail_height]),\n        ),\n        axis=rod.Axis(\n            xyz=rod.Xyz(xyz=[0, 1, 0]),\n            limit=rod.Limit(\n                
upper=(rail_length / 2 - cart_size[1] / 2),\n                lower=-(rail_length / 2 - cart_size[1] / 2),\n                effort=500.0,\n                velocity=10.0,\n            ),\n        ),\n    )\n\n    pivot = rod.Joint(\n        name=\"pivot\",\n        type=\"continuous\",\n        parent=cart_builder.name,\n        child=pole_builder.name,\n        pose=primitives.PrimitiveBuilder.build_pose(\n            relative_to=cart_builder.name,\n        ),\n        axis=rod.Axis(\n            xyz=rod.Xyz(xyz=[1, 0, 0]),\n            limit=rod.Limit(),\n        ),\n    )\n\n    # ================\n    # Create the links\n    # ================\n\n    rail_elements_pose = primitives.PrimitiveBuilder.build_pose(\n        pos=np.array([0, 0, rail_height]),\n        rpy=np.array([np.pi / 2, 0, 0]),\n    )\n\n    rail = (\n        rail_builder.build_link(\n            name=rail_builder.name,\n            pose=primitives.PrimitiveBuilder.build_pose(\n                relative_to=world_to_rail.name,\n            ),\n        )\n        .add_inertial(pose=rail_elements_pose)\n        .add_visual(pose=rail_elements_pose)\n        .add_collision(pose=rail_elements_pose)\n        .build()\n    )\n\n    cart = (\n        cart_builder.build_link(\n            name=cart_builder.name,\n            pose=primitives.PrimitiveBuilder.build_pose(relative_to=linear.name),\n        )\n        .add_inertial()\n        .add_visual()\n        .add_collision()\n        .build()\n    )\n\n    pole_elements_pose = primitives.PrimitiveBuilder.build_pose(\n        pos=np.array([0, 0, pole_length / 2]),\n    )\n\n    pole = (\n        pole_builder.build_link(\n            name=pole_builder.name,\n            pose=primitives.PrimitiveBuilder.build_pose(\n                relative_to=pivot.name,\n            ),\n        )\n        .add_inertial(pose=pole_elements_pose)\n        .add_visual(pose=pole_elements_pose)\n        .add_collision(pose=pole_elements_pose)\n        .build()\n    )\n\n    # 
===========\n    # Build model\n    # ===========\n\n    # Create ROD in-memory model.\n    model = rod.Model(\n        name=\"cartpole\",\n        canonical_link=rail.name,\n        link=[\n            rail,\n            cart,\n            pole,\n        ],\n        joint=[\n            world_to_rail,\n            linear,\n            pivot,\n        ],\n    )\n\n    # Update the pose elements to be closer to those expected in URDF.\n    model.switch_frame_convention(\n        frame_convention=rod.FrameConvention.Urdf, explicit_frames=True\n    )\n\n    # ==============\n    # Get SDF string\n    # ==============\n\n    # Create the top-level SDF object.\n    sdf = rod.Sdf(version=\"1.10\", model=model)\n\n    # Generate the SDF string.\n    # sdf_string = sdf.serialize(pretty=True, validate=True)\n\n    # ===============\n    # Get URDF string\n    # ===============\n\n    import rod.urdf.exporter\n\n    # Convert the SDF to URDF.\n    urdf_string = rod.urdf.exporter.UrdfExporter(\n        pretty=True, indent=\"    \"\n    ).to_urdf_string(sdf=sdf)\n\n    # Print the URDF string.\n    print(urdf_string)\n"
  },
  {
    "path": "examples/assets/cartpole.urdf",
    "content": "<?xml version=\"1.0\" encoding=\"utf-8\"?>\n<robot name=\"cartpole\">\n    <link name=\"world\"/>\n    <link name=\"rail\">\n        <inertial>\n            <origin xyz=\"0.0 0.0 1.2\" rpy=\"1.5707963267948963 0.0 0.0\"/>\n            <mass value=\"5.0\"/>\n            <inertia ixx=\"10.416697916666665\" ixy=\"0.0\" ixz=\"0.0\" iyy=\"10.416697916666665\" iyz=\"0.0\" izz=\"6.25e-05\"/>\n        </inertial>\n        <visual name=\"rail_visual\">\n            <origin xyz=\"0.0 0.0 1.2\" rpy=\"1.5707963267948963 0.0 0.0\"/>\n            <geometry>\n                <cylinder radius=\"0.005\" length=\"5.0\"/>\n            </geometry>\n        </visual>\n        <collision name=\"rail_collision\">\n            <origin xyz=\"0.0 0.0 1.2\" rpy=\"1.5707963267948963 0.0 0.0\"/>\n            <geometry>\n                <cylinder radius=\"0.005\" length=\"5.0\"/>\n            </geometry>\n        </collision>\n    </link>\n    <link name=\"cart\">\n        <inertial>\n            <origin xyz=\"0.0 0.0 0.0\" rpy=\"0.0 0.0 0.0\"/>\n            <mass value=\"1.0\"/>\n            <inertia ixx=\"0.0035416666666666674\" ixy=\"0.0\" ixz=\"0.0\" iyy=\"0.0010416666666666669\" iyz=\"0.0\" izz=\"0.0041666666666666675\"/>\n        </inertial>\n        <visual name=\"cart_visual\">\n            <origin xyz=\"0.0 0.0 0.0\" rpy=\"0.0 0.0 0.0\"/>\n            <geometry>\n                <box size=\"0.1 0.2 0.05\"/>\n            </geometry>\n        </visual>\n        <collision name=\"cart_collision\">\n            <origin xyz=\"0.0 0.0 0.0\" rpy=\"0.0 0.0 0.0\"/>\n            <geometry>\n                <box size=\"0.1 0.2 0.05\"/>\n            </geometry>\n        </collision>\n    </link>\n    <link name=\"pole\">\n        <inertial>\n            <origin xyz=\"0.0 0.0 0.5\" rpy=\"0.0 0.0 0.0\"/>\n            <mass value=\"0.5\"/>\n            <inertia ixx=\"0.04166979166666667\" ixy=\"0.0\" ixz=\"0.0\" iyy=\"0.04166979166666667\" iyz=\"0.0\" izz=\"6.25e-06\"/>\n        
</inertial>\n        <visual name=\"pole_visual\">\n            <origin xyz=\"0.0 0.0 0.5\" rpy=\"0.0 0.0 0.0\"/>\n            <geometry>\n                <cylinder radius=\"0.005\" length=\"1.0\"/>\n            </geometry>\n        </visual>\n        <collision name=\"pole_collision\">\n            <origin xyz=\"0.0 0.0 0.5\" rpy=\"0.0 0.0 0.0\"/>\n            <geometry>\n                <cylinder radius=\"0.005\" length=\"1.0\"/>\n            </geometry>\n        </collision>\n    </link>\n    <link name=\"cart_frame\"/>\n    <link name=\"rail_frame\"/>\n    <joint name=\"cart_frame_joint\" type=\"fixed\">\n        <parent link=\"cart\" />\n        <child link=\"cart_frame\" />\n        <origin xyz=\"0.0 0.0 0.0\" rpy=\"0.0 0.0 0.0\" />\n    </joint>\n    <joint name=\"rail_frame_joint\" type=\"fixed\">\n        <parent link=\"rail\" />\n        <child link=\"rail_frame\" />\n        <origin xyz=\"0.0 0.0 1.2\" rpy=\"0.0 0.0 0.0\" />\n    </joint>\n    <joint name=\"world_to_rail\" type=\"fixed\">\n        <origin xyz=\"0.0 0.0 0.0\" rpy=\"0.0 0.0 0.0\"/>\n        <parent link=\"world\"/>\n        <child link=\"rail\"/>\n    </joint>\n    <joint name=\"linear\" type=\"prismatic\">\n        <origin xyz=\"0.0 0.0 1.2\" rpy=\"0.0 0.0 0.0\"/>\n        <parent link=\"rail\"/>\n        <child link=\"cart\"/>\n        <axis xyz=\"0 1 0\"/>\n        <limit effort=\"500.0\" velocity=\"10.0\" lower=\"-2.4\" upper=\"2.4\"/>\n    </joint>\n    <joint name=\"pivot\" type=\"continuous\">\n        <origin xyz=\"0.0 0.0 0.0\" rpy=\"0.0 0.0 0.0\"/>\n        <parent link=\"cart\"/>\n        <child link=\"pole\"/>\n        <axis xyz=\"1 0 0\"/>\n        <limit effort=\"3.4028235e+38\" velocity=\"3.4028235e+38\"/>\n    </joint>\n</robot>\n"
  },
  {
    "path": "examples/jaxsim_as_multibody_dynamics_library.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"DpLq0-lltwZ1\"\n   },\n   \"source\": [\n    \"# `JaxSim` as a multibody dynamics library\\n\",\n    \"\\n\",\n    \"JaxSim was initially developed as a **hardware-accelerated physics engine**. Over time, it has evolved, adding new features to become a comprehensive **JAX-based multibody dynamics library**.\\n\",\n    \"\\n\",\n    \"In this notebook, you'll explore the main APIs for loading robot models and computing key quantities for applications such as control, planning, and more.\\n\",\n    \"\\n\",\n    \"A key advantage of JaxSim is its ability to create fully differentiable closed-loop systems, enabling end-to-end optimization. Combined with the flexibility to parameterize model kinematics and dynamics, JaxSim can serve as an excellent playground for robot learning applications.\\n\",\n    \"\\n\",\n    \"<a target=\\\"_blank\\\" href=\\\"https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_multibody_dynamics_library.ipynb\\\">\\n\",\n    \"  <img src=\\\"https://colab.research.google.com/assets/colab-badge.svg\\\" alt=\\\"Open In Colab\\\"/>\\n\",\n    \"</a>\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"rcEwprINtwZ3\"\n   },\n   \"source\": [\n    \"## Prepare environment\\n\",\n    \"\\n\",\n    \"First, we need to install the necessary packages and import their resources.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"u4xL7dbBtwZ3\",\n    \"outputId\": \"1a088e28-e005-4910-928c-cb641e589ab5\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Imports and setup\\n\",\n    \"from IPython.display import clear_output\\n\",\n    \"import sys\\n\",\n    \"\\n\",\n    \"IS_COLAB = \\\"google.colab\\\" in sys.modules\\n\",\n    
\"\\n\",\n    \"# Install JAX, sdformat, and other notebook dependencies.\\n\",\n    \"if IS_COLAB:\\n\",\n    \"    !{sys.executable} -m pip install --pre -qU jaxsim\\n\",\n    \"    !{sys.executable} -m pip install robot_descriptions>=1.16.0\\n\",\n    \"    !apt install -qq lsb-release wget gnupg\\n\",\n    \"    !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\\n\",\n    \"    !echo \\\"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\\\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\\n\",\n    \"    !apt -qq update\\n\",\n    \"    !apt install -qq --no-install-recommends libsdformat13 gz-tools2\\n\",\n    \"\\n\",\n    \"    clear_output()\\n\",\n    \"\\n\",\n    \"import os\\n\",\n    \"import pathlib\\n\",\n    \"\\n\",\n    \"import jax\\n\",\n    \"import jax.numpy as jnp\\n\",\n    \"import jaxsim.api as js\\n\",\n    \"import jaxsim.math\\n\",\n    \"from jaxsim import logging\\n\",\n    \"from jaxsim import VelRepr\\n\",\n    \"\\n\",\n    \"logging.set_logging_level(logging.LoggingLevel.WARNING)\\n\",\n    \"print(f\\\"Running on {jax.devices()}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"fN8Xg4QgtwZ4\"\n   },\n   \"source\": [\n    \"## Robot model\\n\",\n    \"\\n\",\n    \"JaxSim allows loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files.\\n\",\n    \"\\n\",\n    \"In this example, we will use the [ErgoCub][ergocub] humanoid robot model. 
If you have a URDF/SDF file for your robot that is compatible with [`gazebosim/sdformat`][sdformat_github][1], it should work out-of-the-box with JaxSim.\\n\",\n    \"\\n\",\n    \"[sdformat]: http://sdformat.org/\\n\",\n    \"[urdf]: http://wiki.ros.org/urdf/\\n\",\n    \"[ergocub]: https://ergocub.eu/\\n\",\n    \"[sdformat_github]: https://github.com/gazebosim/sdformat\\n\",\n    \"\\n\",\n    \"---\\n\",\n    \"\\n\",\n    \"[1]: JaxSim validates robot descriptions using the command `gz sdf -p /path/to/file.urdf`. Ensure this command runs successfully on your file.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"rB0BFxyPtwZ5\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Fetch the URDF file\\n\",\n    \"\\n\",\n    \"try:\\n\",\n    \"    os.environ[\\\"ROBOT_DESCRIPTION_COMMIT\\\"] = \\\"v0.7.7\\\"\\n\",\n    \"\\n\",\n    \"    import robot_descriptions.ergocub_description\\n\",\n    \"\\n\",\n    \"finally:\\n\",\n    \"    _ = os.environ.pop(\\\"ROBOT_DESCRIPTION_COMMIT\\\", None)\\n\",\n    \"\\n\",\n    \"model_description_path = pathlib.Path(\\n\",\n    \"    robot_descriptions.ergocub_description.URDF_PATH.replace(\\n\",\n    \"        \\\"ergoCubSN002\\\", \\\"ergoCubSN001\\\"\\n\",\n    \"    )\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"clear_output()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"jeTUZic8twZ5\"\n   },\n   \"source\": [\n    \"### Create the model and its data\\n\",\n    \"\\n\",\n    \"The dynamics of a generic floating-base model are governed by the following equations of motion:\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"M(\\\\mathbf{q}) \\\\dot{\\\\boldsymbol{\\\\nu}} + \\\\mathbf{h}(\\\\mathbf{q}, \\\\boldsymbol{\\\\nu}) = B \\\\boldsymbol{\\\\tau} + \\\\sum_{L_i \\\\in \\\\mathcal{L}} J_{W,L_i}^\\\\top(\\\\mathbf{q}) \\\\: \\\\mathbf{f}_i\\n\",\n    \".\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"Here, the 
system state is represented by:\\n\",\n    \"\\n\",\n    \"- $\\\\mathbf{q} = ({}^W \\\\mathbf{p}_B, \\\\, \\\\mathbf{s}) \\\\in \\\\text{SE}(3) \\\\times \\\\mathbb{R}^n$ is the generalized position.\\n\",\n    \"- $\\\\boldsymbol{\\\\nu} = (\\\\boldsymbol{v}_{W,B}, \\\\, \\\\boldsymbol{\\\\omega}_{W,B}, \\\\, \\\\dot{\\\\mathbf{s}}) \\\\in \\\\mathbb{R}^{6+n}$ is the generalized velocity.\\n\",\n    \"\\n\",\n    \"The inputs to the system are:\\n\",\n    \"\\n\",\n    \"- $\\\\boldsymbol{\\\\tau} \\\\in \\\\mathbb{R}^n$ is the vector of joint torques.\\n\",\n    \"- $\\\\mathbf{f}_i \\\\in \\\\mathbb{R}^6$ is the 6D force applied to the link $L_i$.\\n\",\n    \"\\n\",\n    \"JaxSim exposes functional APIs to operate over the following two main data structures:\\n\",\n    \"\\n\",\n    \"- **`JaxSimModel`** stores all the constant information parsed from the model description.\\n\",\n    \"- **`JaxSimModelData`** holds the state of the model.\\n\",\n    \"\\n\",\n    \"Additionally, JaxSim includes a utility class, **`JaxSimModelReferences`**, for managing and manipulating system inputs.\\n\",\n    \"\\n\",\n    \"---\\n\",\n    \"\\n\",\n    \"This notebook uses the notation summarized in the following report. 
Please refer to this document if you have any questions or if something is unclear.\\n\",\n    \"\\n\",\n    \"> Traversaro and Saccon, **Multibody dynamics notation**, 2019, [URL](https://pure.tue.nl/ws/portalfiles/portal/139293126/A_Multibody_Dynamics_Notation_Revision_2_.pdf).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"WYgBAxU0twZ6\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Create the model from the model description.\\n\",\n    \"# JaxSim removes all fixed joints by lumping together their parent and child links.\\n\",\n    \"full_model = js.model.JaxSimModel.build_from_model_description(\\n\",\n    \"    model_description=model_description_path\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"DdaETmDStwZ6\"\n   },\n   \"source\": [\n    \"It is often useful to work with only a subset of joints, referred to as the _considered joints_. JaxSim allows to reduce a model so that the computation of the rigid body dynamics quantities is simplified.\\n\",\n    \"\\n\",\n    \"By default, the positions of the removed joints are considered to be zero. 
If this is not the case, the `reduce` function accepts a dictionary `dict[str, float]` to specify custom joint positions.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"QuhG7Zv5twZ7\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"model = js.model.reduce(\\n\",\n    \"    model=full_model,\\n\",\n    \"    considered_joints=tuple(\\n\",\n    \"        j\\n\",\n    \"        for j in full_model.joint_names()\\n\",\n    \"        # Remove sensor joints.\\n\",\n    \"        if \\\"camera\\\" not in j\\n\",\n    \"        # Remove head and hands.\\n\",\n    \"        and \\\"neck\\\" not in j\\n\",\n    \"        and \\\"wrist\\\" not in j\\n\",\n    \"        and \\\"thumb\\\" not in j\\n\",\n    \"        and \\\"index\\\" not in j\\n\",\n    \"        and \\\"middle\\\" not in j\\n\",\n    \"        and \\\"ring\\\" not in j\\n\",\n    \"        and \\\"pinkie\\\" not in j\\n\",\n    \"        # Remove upper body.\\n\",\n    \"        and \\\"torso\\\" not in j and \\\"elbow\\\" not in j and \\\"shoulder\\\" not in j\\n\",\n    \"    ),\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"RLvAit_i2ZiA\",\n    \"outputId\": \"ea3954af-b9b9-46ac-d9cb-20b99b1eac94\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Print model quantities.\\n\",\n    \"print(f\\\"Model name: {model.name()}\\\")\\n\",\n    \"print(f\\\"Number of links: {model.number_of_links()}\\\")\\n\",\n    \"print(f\\\"Number of joints: {model.number_of_joints()}\\\")\\n\",\n    \"\\n\",\n    \"print()\\n\",\n    \"print(f\\\"Links:\\\\n{model.link_names()}\\\")\\n\",\n    \"\\n\",\n    \"print()\\n\",\n    \"print(f\\\"Joints:\\\\n{model.joint_names()}\\\")\\n\",\n    \"\\n\",\n    \"print()\\n\",\n    \"print(f\\\"Frames:\\\\n{model.frame_names()}\\\")\"\n   ]\n  },\n  {\n  
 \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"Xp8V5on5twZ8\",\n    \"outputId\": \"cc1564db-ae91-4dba-92c9-b8b87bd65f10\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Create a random data object from the reduced model.\\n\",\n    \"data = js.data.random_model_data(model=model)\\n\",\n    \"\\n\",\n    \"# Print the default state.\\n\",\n    \"W_H_B, s = data.generalized_position\\n\",\n    \"ν = data.generalized_velocity\\n\",\n    \"\\n\",\n    \"print(f\\\"W_H_B: shape={W_H_B.shape}\\\\n{W_H_B}\\\\n\\\")\\n\",\n    \"print(f\\\"s: shape={s.shape}\\\\n{s}\\\\n\\\")\\n\",\n    \"print(f\\\"ν: shape={ν.shape}\\\\n{ν}\\\\n\\\")  # noqa: RUF001\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"XLx3sv9VtwZ9\",\n    \"outputId\": \"28f5f070-e37e-464e-d84e-2944cfdc28dc\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Create a random link forces matrix.\\n\",\n    \"link_forces = jax.random.uniform(\\n\",\n    \"    minval=-10.0,\\n\",\n    \"    maxval=10.0,\\n\",\n    \"    shape=(model.number_of_links(), 6),\\n\",\n    \"    key=jax.random.PRNGKey(0),\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Create a random joint force references vector.\\n\",\n    \"# Note that these are called 'references' because the actual joint forces that\\n\",\n    \"# are actuated might differ due to effects like joint friction.\\n\",\n    \"joint_force_references = jax.random.uniform(\\n\",\n    \"    minval=-10.0, maxval=10.0, shape=(model.dofs(),), key=jax.random.PRNGKey(0)\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Create the references object.\\n\",\n    \"references = js.references.JaxSimModelReferences.build(\\n\",\n    \"    model=model,\\n\",\n    \"    data=data,\\n\",\n    \"    
link_forces=link_forces,\\n\",\n    \"    joint_force_references=joint_force_references,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print(f\\\"link_forces: shape={references.link_forces(model=model, data=data).shape}\\\")\\n\",\n    \"print(f\\\"joint_force_references: shape={references.joint_force_references(model=model).shape}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"AaG817vP4LfT\"\n   },\n   \"source\": [\n    \"## Robot Kinematics\\n\",\n    \"\\n\",\n    \"JaxSim offers functional APIs for computing kinematic quantities:\\n\",\n    \"\\n\",\n    \"- **`jaxsim.api.model`**: vectorized functions operating on the whole model.\\n\",\n    \"- **`jaxsim.api.link`**: functions operating on individual links.\\n\",\n    \"- **`jaxsim.api.frame`**: functions operating on individual frames.\\n\",\n    \"\\n\",\n    \"Due to JAX limitations on vectorizable data types, many APIs operate on indices instead of names. Since using indices can be error-prone, JaxSim provides conversion functions for both links:\\n\",\n    \"\\n\",\n    \"- **`jaxsim.api.link.names_to_idxs()`**\\n\",\n    \"- **`jaxsim.api.link.idxs_to_names()`**\\n\",\n    \"\\n\",\n    \"and frames:\\n\",\n    \"\\n\",\n    \"- **`jaxsim.api.frame.names_to_idxs()`**\\n\",\n    \"- **`jaxsim.api.frame.idxs_to_names()`**\\n\",\n    \"\\n\",\n    \"We recommend using names whenever possible to avoid hard-to-trace errors.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"QxImwfZz7pz-\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Find the index of a link.\\n\",\n    \"link_name = \\\"l_ankle_2\\\"\\n\",\n    \"link_index = js.link.name_to_idx(model=model, link_name=link_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"C22Iqu2i4G-I\",\n    \"outputId\": 
\"94376151-177d-410f-f375-b7b8bd080992\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Link Pose\\n\",\n    \"\\n\",\n    \"# Compute its pose w.r.t. the world frame through forward kinematics.\\n\",\n    \"W_H_L = js.link.transform(model=model, data=data, link_index=link_index)\\n\",\n    \"\\n\",\n    \"print(f\\\"Transform of '{link_name}': shape={W_H_L.shape}\\\\n{W_H_L}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"DnSpE_f97RkX\",\n    \"outputId\": \"a3f6b535-4ae5-49f4-8921-7fe4dda5debb\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Link 6D Velocity\\n\",\n    \"\\n\",\n    \"# JaxSim allows selecting the so-called representation of the link velocity.\\n\",\n    \"L_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body)\\n\",\n    \"LW_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Mixed)\\n\",\n    \"W_v_WL = js.link.velocity(model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial)\\n\",\n    \"\\n\",\n    \"print(f\\\"Body-fixed velocity      L_v_WL={L_v_WL}\\\")\\n\",\n    \"print(f\\\"Mixed velocity:         LW_v_WL={LW_v_WL}\\\")\\n\",\n    \"print(f\\\"Inertial-fixed velocity: W_v_WL={W_v_WL}\\\")\\n\",\n    \"\\n\",\n    \"# These can also be computed by passing through the link free-floating Jacobian.\\n\",\n    \"# This type of Jacobian has an input velocity representation that corresponds\\n\",\n    \"# to the velocity representation of ν, and an output velocity representation that\\n\",\n    \"# corresponds to the velocity representation of the desired 6D velocity.\\n\",\n    \"\\n\",\n    \"# You can use the following context manager to easily switch between representations.\\n\",\n    \"with data.switch_velocity_representation(VelRepr.Body):\\n\",\n    
\"\\n\",\n    \"    # Body-fixed generalized velocity.\\n\",\n    \"    B_ν = data.generalized_velocity\\n\",\n    \"\\n\",\n    \"    # Free-floating Jacobian accepting a body-fixed generalized velocity and\\n\",\n    \"    # returning an inertial-fixed link velocity.\\n\",\n    \"    W_J_WL_B = js.link.jacobian(\\n\",\n    \"        model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Inertial\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"# Now the following relation should hold.\\n\",\n    \"assert jnp.allclose(W_v_WL, W_J_WL_B @ B_ν)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"SSoziCShtwZ9\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Find the index of a frame.\\n\",\n    \"frame_name = \\\"l_foot_front\\\"\\n\",\n    \"frame_index = js.frame.name_to_idx(model=model, frame_name=frame_name)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"fVp_xP_1twZ9\",\n    \"outputId\": \"cfaa0569-d768-4708-c98c-a5867c056d04\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Frame Pose\\n\",\n    \"\\n\",\n    \"# Compute its pose w.r.t. 
the world frame through forward kinematics.\\n\",\n    \"W_H_F = js.frame.transform(model=model, data=data, frame_index=frame_index)\\n\",\n    \"\\n\",\n    \"print(f\\\"Transform of '{frame_name}': shape={W_H_F.shape}\\\\n{W_H_F}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"QqaqxneEFYiW\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Frame 6D Velocity\\n\",\n    \"\\n\",\n    \"# JaxSim allows selecting the so-called representation of the frame velocity.\\n\",\n    \"F_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Body)\\n\",\n    \"FW_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Mixed)\\n\",\n    \"W_v_WF = js.frame.velocity(model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial)\\n\",\n    \"\\n\",\n    \"print(f\\\"Body-fixed velocity      F_v_WF={F_v_WF}\\\")\\n\",\n    \"print(f\\\"Mixed velocity:         FW_v_WF={FW_v_WF}\\\")\\n\",\n    \"print(f\\\"Inertial-fixed velocity: W_v_WF={W_v_WF}\\\")\\n\",\n    \"\\n\",\n    \"# These can also be computed by passing through the frame free-floating Jacobian.\\n\",\n    \"# This type of Jacobian has an input velocity representation that corresponds\\n\",\n    \"# to the velocity representation of ν, and an output velocity representation that\\n\",\n    \"# corresponds to the velocity representation of the desired 6D velocity.\\n\",\n    \"\\n\",\n    \"# You can use the following context manager to easily switch between representations.\\n\",\n    \"with data.switch_velocity_representation(VelRepr.Body):\\n\",\n    \"\\n\",\n    \"    # Body-fixed generalized velocity.\\n\",\n    \"    B_ν = data.generalized_velocity\\n\",\n    \"\\n\",\n    \"    # Free-floating Jacobian accepting a body-fixed generalized velocity and\\n\",\n    \"    # returning an inertial-fixed frame velocity.\\n\",\n    \"    
W_J_WF_B = js.frame.jacobian(\\n\",\n    \"        model=model, data=data, frame_index=frame_index, output_vel_repr=VelRepr.Inertial\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"# Now the following relation should hold.\\n\",\n    \"assert jnp.allclose(W_v_WF, W_J_WF_B @ B_ν)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"d_vp6D74GoVZ\",\n    \"outputId\": \"798b9283-792e-4339-b56c-df2595fac974\"\n   },\n   \"source\": [\n    \"## Robot Dynamics\\n\",\n    \"\\n\",\n    \"JaxSim provides all the quantities involved in the equations of motion, restated here:\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"M(\\\\mathbf{q}) \\\\dot{\\\\boldsymbol{\\\\nu}} + \\\\mathbf{h}(\\\\mathbf{q}, \\\\boldsymbol{\\\\nu}) = B \\\\boldsymbol{\\\\tau} + \\\\sum_{L_i \\\\in \\\\mathcal{L}} J_{W,L_i}^\\\\top(\\\\mathbf{q}) \\\\: \\\\mathbf{f}_i\\n\",\n    \".\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"Specifically, it can compute:\\n\",\n    \"\\n\",\n    \"- $M(\\\\mathbf{q}) \\\\in \\\\mathbb{R}^{(6+n)\\\\times(6+n)}$: the mass matrix.\\n\",\n    \"- $\\\\mathbf{h}(\\\\mathbf{q}, \\\\boldsymbol{\\\\nu}) \\\\in \\\\mathbb{R}^{6+n}$: the vector of bias forces.\\n\",\n    \"- $B \\\\in \\\\mathbb{R}^{(6+n) \\\\times n}$ the joint selector matrix.\\n\",\n    \"- $J_{W,L} \\\\in \\\\mathbb{R}^{6 \\\\times (6+n)}$ the Jacobian of link $L$.\\n\",\n    \"\\n\",\n    \"Often, for convenience, link Jacobians are stacked together. 
Since JaxSim efficiently computes the Jacobians for all links, using the stacked version is recommended when needed:\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"M(\\\\mathbf{q}) \\\\dot{\\\\boldsymbol{\\\\nu}} + \\\\mathbf{h}(\\\\mathbf{q}, \\\\boldsymbol{\\\\nu}) = B \\\\boldsymbol{\\\\tau} + J_{W,\\\\mathcal{L}}^\\\\top(\\\\mathbf{q}) \\\\: \\\\mathbf{f}_\\\\mathcal{L}\\n\",\n    \".\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"Furthermore, there are applications that require unpacking the vector of bias forces as follows:\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"\\\\mathbf{h}(\\\\mathbf{q}, \\\\boldsymbol{\\\\nu}) = C(\\\\mathbf{q}, \\\\boldsymbol{\\\\nu}) \\\\boldsymbol{\\\\nu} + \\\\mathbf{g}(\\\\mathbf{q})\\n\",\n    \",\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"where:\\n\",\n    \"\\n\",\n    \"- $\\\\mathbf{g}(\\\\mathbf{q}) \\\\in \\\\mathbb{R}^{6+n}$: the vector of gravity forces.\\n\",\n    \"- $C(\\\\mathbf{q}, \\\\boldsymbol{\\\\nu}) \\\\in \\\\mathbb{R}^{(6+n)\\\\times(6+n)}$: the Coriolis matrix.\\n\",\n    \"\\n\",\n    \"Below, we report the functions that compute all these quantities. Note that all quantities depend on the active velocity representation of `data`. As done for the link velocity, it is possible to change the representation associated with all the computed quantities by operating within the corresponding context manager. 
Below, we use the default velocity representation of `data`.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"oOKJOVfsH4Ki\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"print(\\\"Velocity representation of data:\\\", data.velocity_representation, \\\"\\\\n\\\")\\n\",\n    \"\\n\",\n    \"# Compute the mass matrix.\\n\",\n    \"M = js.model.free_floating_mass_matrix(model=model, data=data)\\n\",\n    \"print(f\\\"M:   shape={M.shape}\\\")\\n\",\n    \"\\n\",\n    \"# Compute the vector of bias forces.\\n\",\n    \"h = js.model.free_floating_bias_forces(model=model, data=data)\\n\",\n    \"print(f\\\"h:   shape={h.shape}\\\")\\n\",\n    \"\\n\",\n    \"# Compute the vector of gravity forces.\\n\",\n    \"g = js.model.free_floating_gravity_forces(model=model, data=data)\\n\",\n    \"print(f\\\"g:   shape={g.shape}\\\")\\n\",\n    \"\\n\",\n    \"# Compute the Coriolis matrix.\\n\",\n    \"C = js.model.free_floating_coriolis_matrix(model=model, data=data)\\n\",\n    \"print(f\\\"C:   shape={C.shape}\\\")\\n\",\n    \"\\n\",\n    \"# Create the joint selector matrix.\\n\",\n    \"B = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T\\n\",\n    \"print(f\\\"B:   shape={B.shape}\\\")\\n\",\n    \"\\n\",\n    \"# Compute the stacked tensor of link Jacobians.\\n\",\n    \"J = js.model.generalized_free_floating_jacobian(model=model, data=data)\\n\",\n    \"print(f\\\"J:   shape={J.shape}\\\")\\n\",\n    \"\\n\",\n    \"# Extract the joint forces from the references object.\\n\",\n    \"τ = references.joint_force_references(model=model)\\n\",\n    \"print(f\\\"τ:   shape={τ.shape}\\\")\\n\",\n    \"\\n\",\n    \"# Extract the link forces from the references object.\\n\",\n    \"f_L = references.link_forces(model=model, data=data)\\n\",\n    \"print(f\\\"f_L: shape={f_L.shape}\\\")\\n\",\n    \"\\n\",\n    \"# The following relation should hold.\\n\",\n    \"assert jnp.allclose(h, 
C @ ν + g)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"FlNo8dNWKKtu\",\n    \"outputId\": \"313e939b-f88f-4407-c9ee-b5b3b7443061\"\n   },\n   \"source\": [\n    \"### Forward Dynamics\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"\\\\dot{\\\\boldsymbol{\\\\nu}} = \\\\text{FD}(\\\\mathbf{q}, \\\\boldsymbol{\\\\nu}, \\\\boldsymbol{\\\\tau}, \\\\mathbf{f}_{\\\\mathcal{L}})\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"JaxSim provides two alternative methods to compute the forward dynamics:\\n\",\n    \"\\n\",\n    \"1. Operate on the quantities of the equations of motion.\\n\",\n    \"2. Call the recursive Articulated Body Algorithm (ABA).\\n\",\n    \"\\n\",\n    \"The physics engine provided by JaxSim exploits the efficient calculation of the forward dynamics with ABA for simulating the trajectories of the system dynamics.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"LXARuRu1Ly1K\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"ν̇_eom = jnp.linalg.pinv(M) @ (B @ τ - h + jnp.einsum(\\\"l6g,l6->g\\\", J, f_L))\\n\",\n    \"\\n\",\n    \"v̇_WB, s̈ = js.model.forward_dynamics_aba(\\n\",\n    \"    model=model, data=data, link_forces=f_L, joint_forces=joint_force_references\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"ν̇_aba = jnp.hstack([v̇_WB, s̈])\\n\",\n    \"print(f\\\"ν̇: shape={ν̇_aba.shape}\\\")  # noqa: RUF001\\n\",\n    \"\\n\",\n    \"# The following relation should hold.\\n\",\n    \"assert jnp.allclose(ν̇_eom, ν̇_aba)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"g5GOYXDnLySU\",\n    \"outputId\": \"ad4ce77d-d06f-473a-9c32-040680d76aa5\"\n   },\n   \"source\": [\n    \"### Inverse Dynamics\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"(\\\\boldsymbol{\\\\tau}, \\\\, 
\\\\mathbf{f}_B) = \\\\text{ID}(\\\\mathbf{q}, \\\\boldsymbol{\\\\nu}, \\\\dot{\\\\boldsymbol{\\\\nu}}, \\\\mathbf{f}_{\\\\mathcal{L}})\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"JaxSim offers two methods to compute inverse dynamics:\\n\",\n    \"\\n\",\n    \"- Directly use the quantities from the equations of motion.\\n\",\n    \"- Use the Recursive Newton-Euler Algorithm (RNEA).\\n\",\n    \"\\n\",\n    \"Unlike many other implementations, JaxSim's RNEA for floating-base systems is the true inverse of $\\\\text{FD}$. It also computes the 6D force applied to the base link that generates the base acceleration.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"UTae5MjhaP2H\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"f_B, τ_rnea = js.model.inverse_dynamics(\\n\",\n    \"    model=model,\\n\",\n    \"    data=data,\\n\",\n    \"    base_acceleration=v̇_WB,\\n\",\n    \"    joint_accelerations=s̈,\\n\",\n    \"    # To check that f_B works, let's remove the force applied\\n\",\n    \"    # to the base link from the link forces.\\n\",\n    \"    link_forces=f_L.at[0].set(jnp.zeros(6))\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print(f\\\"f_B:    shape={f_B.shape}\\\")\\n\",\n    \"print(f\\\"τ_rnea: shape={τ_rnea.shape}\\\")\\n\",\n    \"\\n\",\n    \"# The following relations should hold.\\n\",\n    \"assert jnp.allclose(τ_rnea, τ)\\n\",\n    \"assert jnp.allclose(f_B, link_forces[0])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"gYZ1jK1Neg1H\",\n    \"outputId\": \"0de79770-1e18-4027-bb47-5713bc1b4a72\"\n   },\n   \"source\": [\n    \"### Centroidal Dynamics\\n\",\n    \"\\n\",\n    \"Centroidal dynamics is a useful simplification often employed in planning and control applications. 
It represents the dynamics projected onto a mixed frame associated with the center of mass (CoM):\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"G = G[W] = ({}^W \\\\mathbf{p}_{\\\\text{CoM}}, [W])\\n\",\n    \".\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"The governing equations for centroidal dynamics take into account the 6D centroidal momentum:\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"{}_G \\\\mathbf{h} =\\n\",\n    \"\\\\begin{bmatrix}\\n\",\n    \"{}_G \\\\mathbf{h}^l \\\\\\\\ {}_G \\\\mathbf{h}^\\\\omega\\n\",\n    \"\\\\end{bmatrix} =\\n\",\n    \"\\\\begin{bmatrix}\\n\",\n    \"m \\\\, {}^W \\\\dot{\\\\mathbf{p}}_\\\\text{CoM} \\\\\\\\ {}_G \\\\mathbf{h}^\\\\omega\\n\",\n    \"\\\\end{bmatrix}\\n\",\n    \"\\\\in \\\\mathbb{R}^6\\n\",\n    \".\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"The equations of centroidal dynamics can be expressed as:\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"{}_G \\\\dot{\\\\mathbf{h}} =\\n\",\n    \"m \\\\,\\n\",\n    \"\\\\begin{bmatrix}\\n\",\n    \"{}^W \\\\mathbf{g} \\\\\\\\ \\\\mathbf{0}_3\\n\",\n    \"\\\\end{bmatrix} +\\n\",\n    \"\\\\sum_{C_i \\\\in \\\\mathcal{C}} {}_G \\\\mathbf{X}^{C_i} \\\\, {}_{C_i} \\\\mathbf{f}_i\\n\",\n    \".\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"While centroidal dynamics can function independently by considering the total mass $m \\\\in \\\\mathbb{R}$ of the robot and the transformations for 6D contact forces ${}_G \\\\mathbf{X}^{C_i}$ corresponding to the pose ${}^G \\\\mathbf{H}_{C_i} \\\\in \\\\text{SE}(3)$ of the contact frames, advanced kino-dynamic methods may require a relationship between full kinematics and centroidal dynamics. 
This is typically achieved through the _Centroidal Momentum Matrix_ (also known as the _centroidal momentum Jacobian_):\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"{}_G \\\\mathbf{h} = J_\\\\text{CMM}(\\\\mathbf{q}) \\\\, \\\\boldsymbol{\\\\nu}\\n\",\n    \".\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"JaxSim offers APIs to compute all these quantities (and many more) in the `jaxsim.api.com` package.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"rrSfxp8lh9YZ\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Number of contact points.\\n\",\n    \"n_cp = len(model.kin_dyn_parameters.contact_parameters.body)\\n\",\n    \"print(\\\"Number of contact points:\\\", n_cp, \\\"\\\\n\\\")\\n\",\n    \"\\n\",\n    \"# Compute the centroidal momentum.\\n\",\n    \"J_CMM = js.com.centroidal_momentum_jacobian(model=model, data=data)\\n\",\n    \"G_h = J_CMM @ ν\\n\",\n    \"print(f\\\"G_h:    shape={G_h.shape}\\\")\\n\",\n    \"print(f\\\"J_CMM:  shape={J_CMM.shape}\\\")\\n\",\n    \"\\n\",\n    \"# The following relation should hold.\\n\",\n    \"assert jnp.allclose(G_h, js.com.centroidal_momentum(model=model, data=data))\\n\",\n    \"\\n\",\n    \"# If we consider all contact points of the model as active\\n\",\n    \"# (discouraged, since there might be too many), the 6D transforms of\\n\",\n    \"# collidable points can be computed as follows:\\n\",\n    \"W_H_C = js.contact.transforms(model=model, data=data)\\n\",\n    \"\\n\",\n    \"# Compute the pose of the G frame.\\n\",\n    \"W_p_CoM = js.com.com_position(model=model, data=data)\\n\",\n    \"G_H_W = jaxsim.math.Transform.inverse(jnp.eye(4).at[0:3, 3].set(W_p_CoM))\\n\",\n    \"\\n\",\n    \"# Convert from SE(3) to the transforms for 6D forces.\\n\",\n    \"G_Xf_C = jax.vmap(\\n\",\n    \"    lambda W_H_Ci: jaxsim.math.Adjoint.from_transform(\\n\",\n    \"        transform=G_H_W @ W_H_Ci, inverse=True\\n\",\n    \"    )\\n\",\n    \")(W_H_C)\\n\",\n 
   \"print(f\\\"G_Xf_C: shape={G_Xf_C.shape}\\\")\\n\",\n    \"\\n\",\n    \"# Let's create random 3D linear forces applied to the contact points.\\n\",\n    \"C_fl = jax.random.uniform(\\n\",\n    \"    minval=-10.0,\\n\",\n    \"    maxval=10.0,\\n\",\n    \"    shape=(n_cp, 3),\\n\",\n    \"    key=jax.random.PRNGKey(0),\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Compute the total mass of the robot.\\n\",\n    \"m = js.model.total_mass(model=model)\\n\",\n    \"\\n\",\n    \"# The centroidal dynamics can be computed as follows.\\n\",\n    \"G_ḣ = 0\\n\",\n    \"G_ḣ += m * jnp.hstack([0, 0, model.gravity, 0, 0, 0])\\n\",\n    \"G_ḣ += jnp.einsum(\\\"c66,c6->6\\\", G_Xf_C, jnp.hstack([C_fl, jnp.zeros_like(C_fl)]))\\n\",\n    \"print(f\\\"G_ḣ:    shape={G_ḣ.shape}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"Ot6HePB_twaE\",\n    \"outputId\": \"02a6abae-257e-45ee-e9de-6a607cdbeb9a\"\n   },\n   \"source\": [\n    \"## Contact Frames\\n\",\n    \"\\n\",\n    \"Many control and planning applications require projecting the floating-base dynamics into the contact space or computing quantities related to active contact points, such as enforcing holonomic constraints.\\n\",\n    \"\\n\",\n    \"The underlying theory for these applications becomes clearer in a mixed representation. 
Specifically, the position, linear velocity, and linear acceleration of contact points in their corresponding mixed frame align with the numerical derivatives of their coordinate vectors.\\n\",\n    \"\\n\",\n    \"Key methodologies in this area may involve the Delassus matrix:\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"\\\\Psi(\\\\mathbf{q}) = J_{W,C}(\\\\mathbf{q}) \\\\, M(\\\\mathbf{q})^{-1} \\\\, J_{W,C}^T(\\\\mathbf{q})\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"or the linear acceleration of a contact point:\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"{}^W \\\\ddot{\\\\mathbf{p}}_C = \\\\frac{\\\\text{d} (J^l_{W,C} \\\\boldsymbol{\\\\nu})}{\\\\text{d}t}\\n\",\n    \"= \\\\dot{J}^l_{W,C} \\\\boldsymbol{\\\\nu} + J^l_{W,C} \\\\dot{\\\\boldsymbol{\\\\nu}}\\n\",\n    \".\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"JaxSim offers APIs to compute all these quantities (and many more) in the `jaxsim.api.contact` package.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"LITRC3STliKR\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"with (\\n\",\n    \"    data.switch_velocity_representation(VelRepr.Mixed),\\n\",\n    \"    references.switch_velocity_representation(VelRepr.Mixed),\\n\",\n    \"):\\n\",\n    \"\\n\",\n    \"    # Compute the mixed generalized velocity.\\n\",\n    \"    BW_ν = data.generalized_velocity\\n\",\n    \"\\n\",\n    \"    # Compute the mixed generalized acceleration.\\n\",\n    \"    BW_ν̇ = jnp.hstack(\\n\",\n    \"        js.model.forward_dynamics(\\n\",\n    \"            model=model,\\n\",\n    \"            data=data,\\n\",\n    \"            link_forces=references.link_forces(model=model, data=data),\\n\",\n    \"            joint_forces=references.joint_force_references(model=model),\\n\",\n    \"        )\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    # Compute the mass matrix in mixed representation.\\n\",\n    \"    BW_M = 
js.model.free_floating_mass_matrix(model=model, data=data)\\n\",\n    \"\\n\",\n    \"    # Compute the contact Jacobian and its derivative.\\n\",\n    \"    Jl_WC = js.contact.jacobian(model=model, data=data)[:, 0:3, :]\\n\",\n    \"    J̇l_WC = js.contact.jacobian_derivative(model=model, data=data)[:, 0:3, :]\\n\",\n    \"\\n\",\n    \"# Compute the Delassus matrix.\\n\",\n    \"Ψ = jnp.vstack(Jl_WC) @ jnp.linalg.lstsq(BW_M, jnp.vstack(Jl_WC).T)[0]\\n\",\n    \"print(f\\\"Ψ:     shape={Ψ.shape}\\\")\\n\",\n    \"\\n\",\n    \"# Compute the transforms of the mixed frames implicitly associated\\n\",\n    \"# to each collidable point.\\n\",\n    \"W_H_C = js.contact.transforms(model=model, data=data)\\n\",\n    \"print(f\\\"W_H_C: shape={W_H_C.shape}\\\")\\n\",\n    \"\\n\",\n    \"# Compute the linear velocity of the collidable points.\\n\",\n    \"with data.switch_velocity_representation(VelRepr.Mixed):\\n\",\n    \"    W_ṗ_B = js.contact.collidable_point_velocities(model=model, data=data)[:, 0:3]\\n\",\n    \"    print(f\\\"W_ṗ_B: shape={W_ṗ_B.shape}\\\")\\n\",\n    \"\\n\",\n    \"# Compute the linear acceleration of the collidable points.\\n\",\n    \"W_p̈_C = 0\\n\",\n    \"W_p̈_C += jnp.einsum(\\\"c3g,g->c3\\\", J̇l_WC, BW_ν)\\n\",\n    \"W_p̈_C += jnp.einsum(\\\"c3g,g->c3\\\", Jl_WC, BW_ν̇)\\n\",\n    \"print(f\\\"W_p̈_C: shape={W_p̈_C.shape}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"LITRC3STliKR\"\n   },\n   \"source\": [\n    \"## Conclusions\\n\",\n    \"\\n\",\n    \"This notebook provided an overview of the main APIs in JaxSim for its use as a multibody dynamics library. Here are a few key points to remember:\\n\",\n    \"\\n\",\n    \"- Explore all the modules in the `jaxsim.api` package to discover the full range of APIs available. 
Many more functionalities exist beyond what was covered in this notebook.\\n\",\n    \"- All APIs follow a functional approach, consistent with the JAX programming style.\\n\",\n    \"- This functional design allows for easy application of `jax.vmap` to execute functions in parallel on hardware accelerators.\\n\",\n    \"- Since the entire multibody dynamics library is built with JAX, it natively supports `jax.grad`, `jax.jacfwd`, and `jax.jacrev` transformations, enabling automatic differentiation through complex logic without additional effort.\\n\",\n    \"\\n\",\n    \"Have fun!\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"accelerator\": \"GPU\",\n  \"colab\": {\n   \"gpuClass\": \"premium\",\n   \"private_outputs\": true,\n   \"provenance\": [],\n   \"toc_visible\": true\n  },\n  \"kernelspec\": {\n   \"display_name\": \"comodo_jaxsim\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.12.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "examples/jaxsim_as_physics_engine.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"H-WgcgGQaTG7\"\n   },\n   \"source\": [\n    \"# JaxSim as a hardware-accelerated parallel physics engine\\n\",\n    \"\\n\",\n    \"This notebook shows how to use the key APIs to load a robot model and simulate multiple trajectories simultaneously.\\n\",\n    \"\\n\",\n    \"<a target=\\\"_blank\\\" href=\\\"https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_physics_engine.ipynb\\\">\\n\",\n    \"  <img src=\\\"https://colab.research.google.com/assets/colab-badge.svg\\\" alt=\\\"Open In Colab\\\"/>\\n\",\n    \"</a>\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"SgOSnrSscEkt\"\n   },\n   \"source\": [\n    \"## Prepare the environment\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"fdqvAqMDaTG9\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Imports and setup\\n\",\n    \"import os\\n\",\n    \"import sys\\n\",\n    \"from IPython.display import clear_output\\n\",\n    \"\\n\",\n    \"IS_COLAB = \\\"google.colab\\\" in sys.modules\\n\",\n    \"\\n\",\n    \"# Install JAX and Gazebo\\n\",\n    \"if IS_COLAB:\\n\",\n    \"    !{sys.executable} -m pip install --pre -qU jaxsim\\n\",\n    \"    !apt install -qq lsb-release wget gnupg\\n\",\n    \"    !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\\n\",\n    \"    !echo \\\"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\\\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\\n\",\n    \"    !apt -qq update\\n\",\n    \"    !apt install -qq --no-install-recommends libsdformat13 gz-tools2\\n\",\n    \"\\n\",\n    \"    clear_output()\\n\",\n    \"\\n\",\n    \"# 
Set environment variable to avoid GPU out of memory errors\\n\",\n    \"%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# ================\\n\",\n    \"# Notebook imports\\n\",\n    \"# ================\\n\",\n    \"\\n\",\n    \"import jax\\n\",\n    \"import jax.numpy as jnp\\n\",\n    \"import jaxsim.api as js\\n\",\n    \"from jaxsim import logging\\n\",\n    \"import pathlib\\n\",\n    \"\\n\",\n    \"logging.set_logging_level(logging.LoggingLevel.WARNING)\\n\",\n    \"print(f\\\"Running on {jax.devices()}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"NqjuZKvOaTG_\"\n   },\n   \"source\": [\n    \"## Prepare the simulation\\n\",\n    \"\\n\",\n    \"JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. In this example, we will load the [ergoCub][ergocub] model urdf.\\n\",\n    \"\\n\",\n    \"[sdformat]: http://sdformat.org/\\n\",\n    \"[urdf]: http://wiki.ros.org/urdf/\\n\",\n    \"[ergocub]: https://ergocub.eu/\\n\",\n    \"[rod]: https://github.com/gbionics/rod\\n\",\n    \"\\n\",\n    \"### Create the model and its data\\n\",\n    \" To define a simulation we need two main objects:\\n\",\n    \"\\n\",\n    \"- `model`: an object that defines the dynamics of the system.\\n\",\n    \"- `data`: an object that contains the state of the system.\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\\n\",\n    \"To see the advanced usage, check the advanced example, where you will see how to pass explicitly an integrator class and state to the `model` object and how to change the contact model.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Create the model \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"etQ577cFaTHA\"\n   },\n   
\"outputs\": [],\n   \"source\": [\n    \"#  Create the JaxSim model.\\n\",\n    \"try:\\n\",\n    \"    os.environ[\\\"ROBOT_DESCRIPTION_COMMIT\\\"] = \\\"v0.7.7\\\"\\n\",\n    \"\\n\",\n    \"    import robot_descriptions.ergocub_description\\n\",\n    \"\\n\",\n    \"finally:\\n\",\n    \"    _ = os.environ.pop(\\\"ROBOT_DESCRIPTION_COMMIT\\\", None)\\n\",\n    \"\\n\",\n    \"model_description_path = pathlib.Path(\\n\",\n    \"    robot_descriptions.ergocub_description.URDF_PATH.replace(\\n\",\n    \"        \\\"ergoCubSN002\\\", \\\"ergoCubSN001\\\"\\n\",\n    \"    )\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"clear_output()\\n\",\n    \"\\n\",\n    \"full_model = js.model.JaxSimModel.build_from_model_description(\\n\",\n    \"    model_description=model_description_path,\\n\",\n    \"    time_step=0.0001,\\n\",\n    \"    is_urdf=True\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"joints_list = tuple(('l_shoulder_pitch', 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow',\\n\",\n    \"               'r_shoulder_pitch', 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow',\\n\",\n    \"               'l_hip_pitch', 'l_hip_roll', 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll',\\n\",\n    \"               'r_hip_pitch', 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll'))\\n\",\n    \"\\n\",\n    \"model = js.model.reduce(\\n\",\n    \"    model=full_model,\\n\",\n    \"    considered_joints=joints_list\\n\",\n    \")\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Create the data object \\n\",\n    \"\\n\",\n    \"The data object is never changed by reference. 
Anytime you call a method aimed at modifying data, like `reset_base_position`, a new data object will be returned with the updated attributes while the original data will not be changed.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Create the data of a single model.\\n\",\n    \"data = js.data.JaxSimModelData.build(model=model, base_position=jnp.array([0.0, 0.0, 1.0]))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Simulation\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Create a random JAX key.\\n\",\n    \"\\n\",\n    \"key = jax.random.PRNGKey(seed=0)\\n\",\n    \"\\n\",\n    \"# Initialize the simulated time.\\n\",\n    \"T = jnp.arange(start=0, stop=0.3, step=model.time_step)\\n\",\n    \"\\n\",\n    \"# Simulate\\n\",\n    \"for _t in T:\\n\",\n    \"    data = js.model.step(\\n\",\n    \"        model=model,\\n\",\n    \"        data=data,\\n\",\n    \"        link_forces=None,\\n\",\n    \"        joint_force_references=None,\\n\",\n    \"    )\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Vectorized simulation \\n\",\n    \"\\n\",\n    \"We will now vectorize the simulation on batched data using `jax.vmap`\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# first we have to vmap the function\\n\",\n    \"\\n\",\n    \"import functools\\n\",\n    \"from typing import Any\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@jax.jit\\n\",\n    \"def step_single(\\n\",\n    \"    model: js.model.JaxSimModel,\\n\",\n    \"    data: js.data.JaxSimModelData,\\n\",\n    \") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\\n\",\n    \"\\n\",\n    \"    # Close 
step over static arguments.\\n\",\n    \"    return js.model.step(\\n\",\n    \"        model=model,\\n\",\n    \"        data=data,\\n\",\n    \"        link_forces=None,\\n\",\n    \"        joint_force_references=None,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@jax.jit\\n\",\n    \"@functools.partial(jax.vmap, in_axes=(None, 0))\\n\",\n    \"def step_parallel(\\n\",\n    \"    model: js.model.JaxSimModel,\\n\",\n    \"    data: js.data.JaxSimModelData,\\n\",\n    \") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\\n\",\n    \"\\n\",\n    \"    return step_single(\\n\",\n    \"        model=model, data=data\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Then we have to create the vector of initial state\\n\",\n    \"batch_size = 5\\n\",\n    \"data_batch_t0 = jax.vmap(\\n\",\n    \"    lambda pos:  js.data.JaxSimModelData.build(model=model, base_position=pos))(jnp.tile(jnp.array([0.0, 0.0, 1.0]), (batch_size, 1)))\\n\",\n    \"\\n\",\n    \"data = data_batch_t0\\n\",\n    \"for _t in T:\\n\",\n    \"    data = step_parallel(model, data)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"accelerator\": \"GPU\",\n  \"colab\": {\n   \"gpuClass\": \"premium\",\n   \"private_outputs\": true,\n   \"provenance\": [],\n   \"toc_visible\": true\n  },\n  \"kernelspec\": {\n   \"display_name\": \"jaxsim\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.13.1\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "examples/jaxsim_as_physics_engine_advanced.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"H-WgcgGQaTG7\"\n   },\n   \"source\": [\n    \"# JaxSim as a hardware-accelerated parallel physics engine-advanced usage\\n\",\n    \"\\n\",\n    \"JaxSim is developed to optimize synthetic data generation by sampling trajectories using hardware accelerators such as GPUs and TPUs.\\n\",\n    \"\\n\",\n    \"In this notebook, you'll learn how to use the key APIs to load a simple robot model (a sphere) and simulate multiple trajectories in parallel on GPUs.\\n\",\n    \"\\n\",\n    \"<a target=\\\"_blank\\\" href=\\\"https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_as_physics_engine_advanced.ipynb\\\">\\n\",\n    \"  <img src=\\\"https://colab.research.google.com/assets/colab-badge.svg\\\" alt=\\\"Open In Colab\\\"/>\\n\",\n    \"</a>\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"SgOSnrSscEkt\"\n   },\n   \"source\": [\n    \"## Prepare the environment\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"fdqvAqMDaTG9\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Imports and setup\\n\",\n    \"import sys\\n\",\n    \"from IPython.display import clear_output\\n\",\n    \"\\n\",\n    \"IS_COLAB = \\\"google.colab\\\" in sys.modules\\n\",\n    \"\\n\",\n    \"# Install JAX and Gazebo\\n\",\n    \"if IS_COLAB:\\n\",\n    \"    !{sys.executable} -m pip install --pre -qU jaxsim[viz]\\n\",\n    \"    !apt install -qq lsb-release wget gnupg\\n\",\n    \"    !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\\n\",\n    \"    !echo \\\"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\\\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > 
/dev/null\\n\",\n    \"    !apt -qq update\\n\",\n    \"    !apt install -qq --no-install-recommends libsdformat13 gz-tools2\\n\",\n    \"\\n\",\n    \"    clear_output()\\n\",\n    \"\\n\",\n    \"# Set environment variable to avoid GPU out of memory errors\\n\",\n    \"%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false\\n\",\n    \"\\n\",\n    \"# ================\\n\",\n    \"# Notebook imports\\n\",\n    \"# ================\\n\",\n    \"\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"if sys.platform == 'darwin':\\n\",\n    \"    os.environ[\\\"MUJOCO_GL\\\"] = \\\"glfw\\\"\\n\",\n    \"else:\\n\",\n    \"    os.environ[\\\"MUJOCO_GL\\\"] = \\\"egl\\\"\\n\",\n    \"\\n\",\n    \"import jax\\n\",\n    \"\\n\",\n    \"import jax.numpy as jnp\\n\",\n    \"import jaxsim.api as js\\n\",\n    \"import rod\\n\",\n    \"from jaxsim import logging\\n\",\n    \"from rod.builder.primitives import SphereBuilder\\n\",\n    \"\\n\",\n    \"logging.set_logging_level(logging.LoggingLevel.WARNING)\\n\",\n    \"print(f\\\"Running on {jax.devices()}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"QtCCUhdpdGFH\"\n   },\n   \"source\": [\n    \"## Prepare the simulation\\n\",\n    \"\\n\",\n    \"JaxSim supports loading robot descriptions from both [SDF][sdformat] and [URDF][urdf] files. This is done using the [`gbionics/rod`][rod] library, which processes these formats.\\n\",\n    \"\\n\",\n    \"The `rod` library also allows creating in-memory models that can be serialized to SDF or URDF. 
We'll use this functionality to build a sphere model, which will later be used to create the JaxSim model.\\n\",\n    \"\\n\",\n    \"[sdformat]: http://sdformat.org/\\n\",\n    \"[urdf]: http://wiki.ros.org/urdf/\\n\",\n    \"[rod]: https://github.com/gbionics/rod\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"cellView\": \"form\",\n    \"id\": \"0emoMQhCaTG_\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Create the model description of a sphere\\n\",\n    \"\\n\",\n    \"# Create an SDF model.\\n\",\n    \"# The builder takes care of computing the right inertia tensor for you.\\n\",\n    \"rod_sdf = rod.Sdf(\\n\",\n    \"    version=\\\"1.7\\\",\\n\",\n    \"    model=SphereBuilder(radius=0.10, mass=1.0, name=\\\"sphere\\\")\\n\",\n    \"    .build_model()\\n\",\n    \"    .add_link()\\n\",\n    \"    .add_inertial()\\n\",\n    \"    .add_visual()\\n\",\n    \"    .add_collision()\\n\",\n    \"    .build(),\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Use rod to change the frames with respect to which the poses are expressed.\\n\",\n    \"rod_sdf.model.switch_frame_convention(\\n\",\n    \"    frame_convention=rod.FrameConvention.Urdf, explicit_frames=True\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Serialize the model to an SDF string.\\n\",\n    \"model_sdf_string = rod_sdf.serialize(pretty=True)\\n\",\n    \"print(model_sdf_string)\\n\",\n    \"\\n\",\n    \"# JaxSim currently only supports collisions between points attached to bodies\\n\",\n    \"# and a ground surface modeled as a heightmap sampled from a smooth function.\\n\",\n    \"# While this approach is universal as it applies to generic meshes, the number\\n\",\n    \"# of considered points greatly affects the performance. Spheres, by default,\\n\",\n    \"# are discretized with 250 points. 
It's too much for this simple example.\\n\",\n    \"# This number can be decreased with the following environment variable.\\n\",\n    \"os.environ[\\\"JAXSIM_COLLISION_SPHERE_POINTS\\\"] = \\\"50\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"NqjuZKvOaTG_\"\n   },\n   \"source\": [\n    \"### Create the model and its data\\n\",\n    \"\\n\",\n    \"JaxSim offers a simple functional API to interact with the simulation in a memory-efficient way. Four main elements are used to define a simulation:\\n\",\n    \"\\n\",\n    \"- `model`: an object that defines the dynamics of the system.\\n\",\n    \"- `data`: an object that contains the state of the system.\\n\",\n    \"- `integrator` *(Optional)*: an object that defines the integration method.\\n\",\n    \"- `integrator_metadata` *(Optional)*: an object that contains the state of the integrator.\\n\",\n    \"\\n\",\n    \"The `JaxSimModel` object contains the simulation time step, the integrator and the contact model.\\n\",\n    \"In this example, we only set the time step of the `model` object and rely on the default integrator and the default `SoftContacts` contact model.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"etQ577cFaTHA\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Create the JaxSim model.\\n\",\n    \"# This is shared among all the parallel instances.\\n\",\n    \"model = js.model.JaxSimModel.build_from_model_description(\\n\",\n    \"    model_description=model_sdf_string,\\n\",\n    \"    time_step=0.001,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Create the data of a single model.\\n\",\n    \"# We will create a vectorized instance later.\\n\",\n    \"data_single = js.data.JaxSimModelData.zero(model=model)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"o86Teq5piVGj\"\n   },\n   \"outputs\": [],\n   \"source\": 
[\n    \"# Initialize the simulated time.\\n\",\n    \"T = jnp.arange(start=0, stop=1.0, step=model.time_step)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"V6IeD2B3m4F0\"\n   },\n   \"source\": [\n    \"## Sample a batch of trajectories in parallel\\n\",\n    \"\\n\",\n    \"You can step through an open-loop trajectory of a single model using `jaxsim.api.model.step`.\\n\",\n    \"\\n\",\n    \"In this notebook, we'll focus on running parallel steps. We'll use JAX's automatic vectorization to apply the step function to batched data.\\n\",\n    \"\\n\",\n    \"Note that these parallel simulations are independent — models don't interact, so there's no need to avoid initial collisions.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"vtEn0aIzr_2j\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Generate batched initial data\\n\",\n    \"\\n\",\n    \"# Create a random JAX key.\\n\",\n    \"key = jax.random.PRNGKey(seed=0)\\n\",\n    \"\\n\",\n    \"# Split subkeys for sampling random initial data.\\n\",\n    \"batch_size = 9\\n\",\n    \"row_length = int(jnp.sqrt(batch_size))\\n\",\n    \"row_dist = 0.3 * row_length\\n\",\n    \"key, *subkeys = jax.random.split(key=key, num=batch_size + 1)\\n\",\n    \"\\n\",\n    \"# Create the batched data by sampling the height from [0.3, 1.2] meters.\\n\",\n    \"data_batch_t0 = jax.vmap(\\n\",\n    \"    lambda key: js.data.random_model_data(\\n\",\n    \"        model=model,\\n\",\n    \"        key=key,\\n\",\n    \"        base_pos_bounds=([0, 0, 0.3], [0, 0, 1.2]),\\n\",\n    \"        base_vel_lin_bounds=(0, 0),\\n\",\n    \"        base_vel_ang_bounds=(0, 0),\\n\",\n    \"    )\\n\",\n    \")(jnp.vstack(subkeys))\\n\",\n    \"\\n\",\n    \"x, y = jnp.meshgrid(\\n\",\n    \"    jnp.linspace(-row_dist, row_dist, num=row_length),\\n\",\n    \"    jnp.linspace(-row_dist, row_dist, 
num=row_length),\\n\",\n    \")\\n\",\n    \"xy_coordinate = jnp.stack([x.flatten(), y.flatten()], axis=-1)\\n\",\n    \"\\n\",\n    \"# Reset the x and y position to a grid.\\n\",\n    \"data_batch_t0 = data_batch_t0.replace(\\n\",\n    \"    model=model,\\n\",\n    \"    base_position=data_batch_t0.base_position.at[:, :2].set(xy_coordinate),\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"print(\\\"W_p_B(t0)=\\\\n\\\", data_batch_t0.base_position[0:10])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"0tQPfsl6uxHm\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Create parallel step function\\n\",\n    \"\\n\",\n    \"import functools\\n\",\n    \"from typing import Any\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@jax.jit\\n\",\n    \"def step_single(\\n\",\n    \"    model: js.model.JaxSimModel,\\n\",\n    \"    data: js.data.JaxSimModelData,\\n\",\n    \") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\\n\",\n    \"\\n\",\n    \"    # Close step over static arguments.\\n\",\n    \"    return js.model.step(\\n\",\n    \"        model=model,\\n\",\n    \"        data=data,\\n\",\n    \"        link_forces=None,\\n\",\n    \"        joint_force_references=None,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@jax.jit\\n\",\n    \"@functools.partial(jax.vmap, in_axes=(None, 0))\\n\",\n    \"def step_parallel(\\n\",\n    \"    model: js.model.JaxSimModel,\\n\",\n    \"    data: js.data.JaxSimModelData,\\n\",\n    \") -> tuple[js.data.JaxSimModelData, dict[str, Any]]:\\n\",\n    \"\\n\",\n    \"    return step_single(\\n\",\n    \"        model=model, data=data\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# The first run will be slow since JAX needs to JIT-compile the functions.\\n\",\n    \"_ = step_single(model, data_single)\\n\",\n    \"_ = step_parallel(model, data_batch_t0)\\n\",\n    \"\\n\",\n    \"# Benchmark the execution of a single step.\\n\",\n    
\"print(\\\"\\\\nSingle simulation step:\\\")\\n\",\n    \"%timeit step_single(model, data_single)\\n\",\n    \"\\n\",\n    \"# On hardware accelerators, there's a range of batch_size values where\\n\",\n    \"# increasing the number of parallel instances doesn't affect computation time.\\n\",\n    \"# This range depends on the GPU/TPU specifications.\\n\",\n    \"print(f\\\"\\\\nParallel simulation steps (batch_size={batch_size} on {jax.devices()[0]}):\\\")\\n\",\n    \"%timeit step_parallel(model, data_batch_t0)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"VNwzT2JQ1n15\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Run parallel simulation\\n\",\n    \"\\n\",\n    \"data = data_batch_t0\\n\",\n    \"data_trajectory_list = []\\n\",\n    \"\\n\",\n    \"for _ in T:\\n\",\n    \"\\n\",\n    \"    data = step_parallel(model, data)\\n\",\n    \"    data_trajectory_list.append(data)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"Y6n720Cr3G44\"\n   },\n   \"source\": [\n    \"## Visualize trajectory\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"BLPODyKr3Lyg\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Convert a list of PyTrees to a batched PyTree.\\n\",\n    \"# This operation is called 'tree transpose' in JAX.\\n\",\n    \"data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\\n\",\n    \"\\n\",\n    \"print(f\\\"W_p_B: shape={data_trajectory.base_position.shape}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"-jxJXy5r3RMt\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import matplotlib.pyplot as plt\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"plt.plot(T, data_trajectory.base_position[:, :, 2])\\n\",\n    \"plt.grid(True)\\n\",\n    \"plt.xlabel(\\\"Time 
[s]\\\")\\n\",\n    \"plt.ylabel(\\\"Height [m]\\\")\\n\",\n    \"plt.title(\\\"Height trajectory of the sphere\\\")\\n\",\n    \"plt.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import jaxsim.mujoco\\n\",\n    \"\\n\",\n    \"mjcf_string, assets = jaxsim.mujoco.ModelToMjcf.convert(\\n\",\n    \"    model.built_from,\\n\",\n    \"    cameras=jaxsim.mujoco.loaders.MujocoCamera.build_from_target_view(\\n\",\n    \"        camera_name=\\\"sphere_cam\\\",\\n\",\n    \"        lookat=[0, 0, 0.3],\\n\",\n    \"        distance=4,\\n\",\n    \"        azimuth=150,\\n\",\n    \"        elevation=-10,\\n\",\n    \"    ),\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Create a helper for each parallel instance.\\n\",\n    \"mj_model_helpers = [\\n\",\n    \"    jaxsim.mujoco.MujocoModelHelper.build_from_xml(\\n\",\n    \"        mjcf_description=mjcf_string, assets=assets\\n\",\n    \"    )\\n\",\n    \"    for _ in range(batch_size)\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"# Create the video recorder.\\n\",\n    \"recorder = jaxsim.mujoco.MujocoVideoRecorder(\\n\",\n    \"    model=mj_model_helpers[0].model,\\n\",\n    \"    data=[helper.data for helper in mj_model_helpers],\\n\",\n    \"    fps=int(1 / model.time_step),\\n\",\n    \"    width=320 * 2,\\n\",\n    \"    height=240 * 2,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"for data_t in data_trajectory_list:\\n\",\n    \"\\n\",\n    \"    for helper, base_position, base_quaternion, joint_position in zip(\\n\",\n    \"        mj_model_helpers,\\n\",\n    \"        data_t.base_position,\\n\",\n    \"        data_t.base_orientation,\\n\",\n    \"        data_t.joint_positions,\\n\",\n    \"        strict=True,\\n\",\n    \"    ):\\n\",\n    \"        helper.set_base_position(position=base_position)\\n\",\n    \"        helper.set_base_orientation(orientation=base_quaternion)\\n\",\n    \"\\n\",\n    \"        if 
model.dofs() > 0:\\n\",\n    \"            helper.set_joint_positions(\\n\",\n    \"                positions=joint_position, joint_names=model.joint_names()\\n\",\n    \"            )\\n\",\n    \"\\n\",\n    \"    # Record a new video frame.\\n\",\n    \"    recorder.record_frame(camera_name=\\\"sphere_cam\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import mediapy as media\\n\",\n    \"\\n\",\n    \"media.show_video(recorder.frames, fps=recorder.fps)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"accelerator\": \"GPU\",\n  \"colab\": {\n   \"gpuClass\": \"premium\",\n   \"private_outputs\": true,\n   \"provenance\": [],\n   \"toc_visible\": true\n  },\n  \"kernelspec\": {\n   \"display_name\": \"jaxpypi\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.13.1\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "examples/jaxsim_for_robot_controllers.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"EhPy6FgiZH4d\"\n   },\n   \"source\": [\n    \"# JaxSim for developing closed-loop robot controllers\\n\",\n    \"\\n\",\n    \"Originally developed as a **hardware-accelerated physics engine**, JaxSim has expanded its capabilities to become a full-featured **JAX-based multibody dynamics library**.\\n\",\n    \"\\n\",\n    \"In this notebook, you'll explore how to combine these two core features. Specifically, you'll learn how to load a robot model and design a model-based controller for closed-loop simulations.\\n\",\n    \"\\n\",\n    \"<a target=\\\"_blank\\\" href=\\\"https://colab.research.google.com/github/gbionics/jaxsim/blob/main/examples/jaxsim_for_robot_controllers.ipynb\\\">\\n\",\n    \"  <img src=\\\"https://colab.research.google.com/assets/colab-badge.svg\\\" alt=\\\"Open In Colab\\\"/>\\n\",\n    \"</a>\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"vsf1AlxdZH4f\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Prepare the environment\\n\",\n    \"from IPython.display import clear_output\\n\",\n    \"import sys\\n\",\n    \"\\n\",\n    \"IS_COLAB = \\\"google.colab\\\" in sys.modules\\n\",\n    \"\\n\",\n    \"# Install JAX, sdformat, and other notebook dependencies.\\n\",\n    \"if IS_COLAB:\\n\",\n    \"    !{sys.executable} -m pip install --pre -qU jaxsim[viz]\\n\",\n    \"    !apt install -qq lsb-release wget gnupg\\n\",\n    \"    !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg\\n\",\n    \"    !echo \\\"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main\\\" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null\\n\",\n    \"    !apt -qq update\\n\",\n    \"    !apt 
install -qq --no-install-recommends libsdformat13 gz-tools2\\n\",\n    \"\\n\",\n    \"    clear_output()\\n\",\n    \"\\n\",\n    \"# ================\\n\",\n    \"# Notebook imports\\n\",\n    \"# ================\\n\",\n    \"\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"if sys.platform == 'darwin':\\n\",\n    \"    os.environ[\\\"MUJOCO_GL\\\"] = \\\"glfw\\\"\\n\",\n    \"else:\\n\",\n    \"    os.environ[\\\"MUJOCO_GL\\\"] = \\\"egl\\\"\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"import jax\\n\",\n    \"import jax.numpy as jnp\\n\",\n    \"import jaxsim.mujoco\\n\",\n    \"from jaxsim import logging\\n\",\n    \"\\n\",\n    \"logging.set_logging_level(logging.LoggingLevel.WARNING)\\n\",\n    \"print(f\\\"Running on {jax.devices()}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"kN-b9nOsZH4g\"\n   },\n   \"source\": [\n    \"We will use a simple cartpole model for this example. The cartpole model is a 2D model with a cart that can move horizontally and a pole that can rotate relative to the cart. The state of the cartpole is given by the position of the cart, the angle of the pole, the velocity of the cart, and the angular velocity of the pole. In the classic formulation, the only control input is the horizontal force applied to the cart; in this example, however, both joints are actuated, so the controller commands a force on the cart and a torque on the pole.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"5aLqrZDqR5LA\"\n   },\n   \"source\": [\n    \"## Prepare the simulation\\n\",\n    \"\\n\",\n    \"JaxSim supports loading robot models from both [SDF][sdformat] and [URDF][urdf] files, utilizing the [`gbionics/rod`][rod] library for processing these formats.\\n\",\n    \"\\n\",\n    \"The `rod` library can read URDF files and validates them internally using [`gazebosim/sdformat`][sdformat_github]. 
In this example, we'll load a cart-pole model, which will be used to create the JaxSim simulation model.\\n\",\n    \"\\n\",\n    \"[sdformat]: http://sdformat.org/\\n\",\n    \"[urdf]: http://wiki.ros.org/urdf/\\n\",\n    \"[rod]: https://github.com/gbionics/rod\\n\",\n    \"[sdformat_github]: https://github.com/gazebosim/sdformat\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"\\n\",\n    \"os.path.abspath(\\\"\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"PZM7hEvFZH4h\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Load the URDF model\\n\",\n    \"import pathlib\\n\",\n    \"import urllib\\n\",\n    \"\\n\",\n    \"# Retrieve the file\\n\",\n    \"url = \\\"https://raw.githubusercontent.com/gbionics/jaxsim/refs/heads/main/examples/assets/cartpole.urdf\\\"\\n\",\n    \"model_path, _ = urllib.request.urlretrieve(url)\\n\",\n    \"model_urdf_string = pathlib.Path(model_path).read_text()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"M5XsKehvZH4j\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Create the model and its data\\n\",\n    \"\\n\",\n    \"import jaxsim.api as js\\n\",\n    \"\\n\",\n    \"# Create the model from the model description.\\n\",\n    \"model = js.model.JaxSimModel.build_from_model_description(\\n\",\n    \"    model_description=model_urdf_string,\\n\",\n    \"    time_step=0.010,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Create the data storing the simulation state.\\n\",\n    \"data_zero = js.data.JaxSimModelData.zero(model=model)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"jk9csR5ETgn1\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Define simulation 
parameters\\n\",\n    \"\\n\",\n    \"# Initialize the simulated time.\\n\",\n    \"T = jnp.arange(start=0, stop=5.0, step=model.time_step)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"bo6Ke5nAWL-S\"\n   },\n   \"source\": [\n    \"## Prepare the MuJoCo renderer\\n\",\n    \"\\n\",\n    \"For visualization purposes, we rely on the MuJoCo simulator. When used locally, its passive viewer can open an interactive window; when used in notebooks, as here, we record the simulation as a video instead.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"j1_I2i5TZH4n\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Create the MJCF resources from the URDF.\\n\",\n    \"mjcf_string, assets = jaxsim.mujoco.UrdfToMjcf.convert(\\n\",\n    \"    urdf=model.built_from,\\n\",\n    \"    # Create the camera used by the recorder.\\n\",\n    \"    cameras=jaxsim.mujoco.loaders.MujocoCamera.build_from_target_view(\\n\",\n    \"        camera_name=\\\"cartpole_camera\\\",\\n\",\n    \"        lookat=js.link.com_position(\\n\",\n    \"            model=model,\\n\",\n    \"            data=data_zero,\\n\",\n    \"            link_index=js.link.name_to_idx(model=model, link_name=\\\"cart\\\"),\\n\",\n    \"            in_link_frame=False,\\n\",\n    \"        ),\\n\",\n    \"        distance=3,\\n\",\n    \"        azimuth=150,\\n\",\n    \"        elevation=-10,\\n\",\n    \"    ),\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Create a helper to operate on the MuJoCo model and data.\\n\",\n    \"mj_model_helper = jaxsim.mujoco.MujocoModelHelper.build_from_xml(\\n\",\n    \"    mjcf_description=mjcf_string, assets=assets\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Create the video recorder.\\n\",\n    \"recorder = jaxsim.mujoco.MujocoVideoRecorder(\\n\",\n    \"    model=mj_model_helper.model,\\n\",\n    \"    data=mj_model_helper.data,\\n\",\n    \"    fps=int(1 / model.time_step),\\n\",\n    \"    
width=320 * 2,\\n\",\n    \"    height=240 * 2,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"DpRvvGujZH4o\"\n   },\n   \"source\": [\n    \"## Open-loop simulation\\n\",\n    \"\\n\",\n    \"Now, let's run a simulation to demonstrate the open-loop dynamics of the system.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"gSWzcsKWZH4p\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import mediapy as media\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Create a random joint position.\\n\",\n    \"# For a random full state, you can use jaxsim.api.data.random_model_data.\\n\",\n    \"random_joint_positions = jax.random.uniform(\\n\",\n    \"    minval=-1.0,\\n\",\n    \"    maxval=1.0,\\n\",\n    \"    shape=(model.dofs(),),\\n\",\n    \"    key=jax.random.PRNGKey(0),\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# Reset the state to the random joint positions.\\n\",\n    \"data = js.data.JaxSimModelData.build(model=model, joint_positions=random_joint_positions)\\n\",\n    \"\\n\",\n    \"for _ in T:\\n\",\n    \"\\n\",\n    \"    # Step the JaxSim simulation.\\n\",\n    \"    data = js.model.step(\\n\",\n    \"        model=model,\\n\",\n    \"        data=data,\\n\",\n    \"        joint_force_references=None,\\n\",\n    \"        link_forces=None,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    # Update the MuJoCo data.\\n\",\n    \"    mj_model_helper.set_joint_positions(\\n\",\n    \"        positions=data.joint_positions, joint_names=model.joint_names()\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    # Record a new video frame.\\n\",\n    \"    recorder.record_frame(camera_name=\\\"cartpole_camera\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# Play the video.\\n\",\n    \"media.show_video(recorder.frames, fps=recorder.fps)\\n\",\n    \"recorder.frames = []\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": 
\"j1rguK3UZH4p\"\n   },\n   \"source\": [\n    \"## Closed-loop simulation\\n\",\n    \"\\n\",\n    \"Next, let's design a simple computed torque controller. The equations of motion for the cart-pole system are given by:\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"M_{ss}(\\\\mathbf{s}) \\\\, \\\\ddot{\\\\mathbf{s}} + \\\\mathbf{h}_s(\\\\mathbf{s}, \\\\dot{\\\\mathbf{s}}) = \\\\boldsymbol{\\\\tau}\\n\",\n    \",\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"where:\\n\",\n    \"\\n\",\n    \"- $\\\\mathbf{s} \\\\in \\\\mathbb{R}^n$ are the joint positions.\\n\",\n    \"- $\\\\dot{\\\\mathbf{s}} \\\\in \\\\mathbb{R}^n$ are the joint velocities.\\n\",\n    \"- $\\\\ddot{\\\\mathbf{s}} \\\\in \\\\mathbb{R}^n$ are the joint accelerations.\\n\",\n    \"- $\\\\boldsymbol{\\\\tau} \\\\in \\\\mathbb{R}^n$ are the joint torques.\\n\",\n    \"- $M_{ss} \\\\in \\\\mathbb{R}^{n \\\\times n}$ is the mass matrix.\\n\",\n    \"- $\\\\mathbf{h}_s \\\\in \\\\mathbb{R}^n$ is the vector of bias forces.\\n\",\n    \"\\n\",\n    \"JaxSim computes these quantities for floating-base systems, so we specifically focus on the joint-related portions by marking them with subscripts.\\n\",\n    \"\\n\",\n    \"Since no external forces or joint friction are present, we can extend a PD controller with a feed-forward term that includes gravity compensation:\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"\\\\begin{cases}\\n\",\n    \"\\\\boldsymbol{\\\\tau} &= M_{ss} \\\\, \\\\ddot{\\\\mathbf{s}}^* + \\\\mathbf{h}_s \\\\\\\\\\n\",\n    \"\\\\ddot{\\\\mathbf{s}}^* &= \\\\ddot{\\\\mathbf{s}}^\\\\text{des} - k_p(\\\\mathbf{s} - \\\\mathbf{s}^{\\\\text{des}}) - k_d(\\\\dot{\\\\mathbf{s}} - \\\\dot{\\\\mathbf{s}}^{\\\\text{des}})\\n\",\n    \"\\\\end{cases}\\n\",\n    \"\\\\quad\\n\",\n    \",\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"where $\\\\tilde{\\\\mathbf{s}} = \\\\left(\\\\mathbf{s} - \\\\mathbf{s}^\\\\text{des}\\\\right)$ is the joint position error.\\n\",\n    \"\\n\",\n    \"With this control 
law, the closed-loop error dynamics simplify to:\\n\",\n    \"\\n\",\n    \"$$\\n\",\n    \"\\\\ddot{\\\\tilde{\\\\mathbf{s}}} = -k_p \\\\tilde{\\\\mathbf{s}} - k_d \\\\dot{\\\\tilde{\\\\mathbf{s}}}\\n\",\n    \",\\n\",\n    \"$$\\n\",\n    \"\\n\",\n    \"which, for any positive gains $k_p, k_d > 0$, drives the error $\\\\tilde{\\\\mathbf{s}}$ asymptotically to zero, ensuring stability.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"rfMTCMyGZH4q\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Create the computed torque controller\\n\",\n    \"\\n\",\n    \"# Define the PD gains\\n\",\n    \"kp = 10.0\\n\",\n    \"kd = 6.0\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def computed_torque_controller(\\n\",\n    \"    data: js.data.JaxSimModelData,\\n\",\n    \"    s_des: jax.Array,\\n\",\n    \"    s_dot_des: jax.Array,\\n\",\n    \") -> jax.Array:\\n\",\n    \"\\n\",\n    \"    # Compute the gravity compensation term.\\n\",\n    \"    hs = js.model.free_floating_bias_forces(model=model, data=data)[6:]\\n\",\n    \"\\n\",\n    \"    # Compute the joint-related portion of the floating-base mass matrix.\\n\",\n    \"    Mss = js.model.free_floating_mass_matrix(model=model, data=data)[6:, 6:]\\n\",\n    \"\\n\",\n    \"    # Get the current joint positions and velocities.\\n\",\n    \"    s = data.joint_positions\\n\",\n    \"    ṡ = data.joint_velocities\\n\",\n    \"\\n\",\n    \"    # Compute the actuated joint torques.\\n\",\n    \"    s_star = -kp * (s - s_des) - kd * (ṡ - s_dot_des)\\n\",\n    \"    τ = Mss @ s_star + hs\\n\",\n    \"\\n\",\n    \"    return τ\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"ERAUisywZH4q\"\n   },\n   \"source\": [\n    \"Now, we can use the `computed_torque_controller` function to compute the joint torques to apply to the cartpole. 
Our aim is to stabilize the cartpole in the upright position, so we set both the desired joint positions `s_des` and the desired joint velocities `s_dot_des` to zero.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"8YmDdGDVZH4q\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# @title Run the simulation\\n\",\n    \"\\n\",\n    \"# Initialize the data.\\n\",\n    \"\\n\",\n    \"# Set the joint positions.\\n\",\n    \"data = js.data.JaxSimModelData.build(model=model, joint_positions=jnp.array([-0.25, jnp.deg2rad(160)]), joint_velocities=jnp.array([3.00, jnp.deg2rad(10) / model.time_step]))\\n\",\n    \"\\n\",\n    \"for _ in T:\\n\",\n    \"\\n\",\n    \"    # Get the actuated torques from the computed torque controller.\\n\",\n    \"    τ = computed_torque_controller(\\n\",\n    \"        data=data,\\n\",\n    \"        s_des=jnp.array([0.0, 0.0]),\\n\",\n    \"        s_dot_des=jnp.array([0.0, 0.0]),\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    # Step the JaxSim simulation.\\n\",\n    \"    data = js.model.step(\\n\",\n    \"        model=model,\\n\",\n    \"        data=data,\\n\",\n    \"        joint_force_references=τ,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    # Update the MuJoCo data.\\n\",\n    \"    mj_model_helper.set_joint_positions(\\n\",\n    \"        positions=data.joint_positions, joint_names=model.joint_names()\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    # Record a new video frame.\\n\",\n    \"    recorder.record_frame(camera_name=\\\"cartpole_camera\\\")\\n\",\n    \"\\n\",\n    \"media.show_video(recorder.frames, fps=recorder.fps)\\n\",\n    \"recorder.frames = []\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"sZ76QqeWeMQz\"\n   },\n   \"source\": [\n    \"## Conclusions\\n\",\n    \"\\n\",\n    \"In this notebook, we explored how to use JaxSim for developing a closed-loop controller for a robot model. 
Key takeaways include:\\n\",\n    \"\\n\",\n    \"- We performed an open-loop simulation to understand the dynamics of the system without control.\\n\",\n    \"- We implemented a computed torque controller with PD feedback and a feed-forward gravity compensation term, enabling the stabilization of the system by controlling joint torques.\\n\",\n    \"- The closed-loop simulation can leverage hardware acceleration on GPUs and TPUs, with the ability to use `jax.vmap` for parallel sampling through automatic vectorization.\\n\",\n    \"\\n\",\n    \"JaxSim's closed-loop support can be extended to more advanced, model-based reactive controllers and planners for trajectory optimization. To explore optimization-based methods, consider the following JAX-based projects for hardware-accelerated control and planning:\\n\",\n    \"\\n\",\n    \"- [`deepmind/optax`](https://github.com/google-deepmind/optax)\\n\",\n    \"- [`google/jaxopt`](https://github.com/google/jaxopt)\\n\",\n    \"- [`patrick-kidger/lineax`](https://github.com/patrick-kidger/lineax)\\n\",\n    \"- [`patrick-kidger/optimistix`](https://github.com/patrick-kidger/optimistix)\\n\",\n    \"- [`kevin-tracy/qpax`](https://github.com/kevin-tracy/qpax)\\n\",\n    \"\\n\",\n    \"Additionally, if your controllers or planners require the derivatives of the dynamics with respect to the state or inputs, you can obtain them using automatic differentiation directly through JaxSim's API.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"accelerator\": \"GPU\",\n  \"colab\": {\n   \"gpuClass\": \"premium\",\n   \"private_outputs\": true,\n   \"provenance\": [],\n   \"toc_visible\": true\n  },\n  \"kernelspec\": {\n   \"display_name\": \"comodo_jaxsim\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": 
\".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.12.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"jaxsim\"\ndynamic = [\"version\"]\nrequires-python = \">= 3.10\"\ndescription = \"A differentiable physics engine and multibody dynamics library for control and robot learning.\"\nauthors = [\n    { name = \"Diego Ferigo\", email = \"dgferigo@gmail.com\" },\n    { name = \"Filippo Luca Ferretti\", email = \"filippoluca.ferretti@outlook.com\" },\n]\nmaintainers = [\n    { name = \"Filippo Luca Ferretti\", email = \"filippo.ferretti@outlook.com\" },\n]\nlicense = \"BSD-3-Clause\"\nlicense-files = [\"LICENSE\"]\nkeywords = [\n    \"physics\",\n    \"physics engine\",\n    \"jax\",\n    \"rigid body dynamics\",\n    \"featherstone\",\n    \"reinforcement learning\",\n    \"robot\",\n    \"robotics\",\n    \"sdf\",\n    \"urdf\",\n]\nclassifiers = [\n    \"Development Status :: 4 - Beta\",\n    \"Framework :: Robot Framework\",\n    \"Intended Audience :: Developers\",\n    \"Intended Audience :: Science/Research\",\n    \"Operating System :: POSIX :: Linux\",\n    \"Operating System :: MacOS\",\n    \"Operating System :: Microsoft\",\n    \"Programming Language :: Python :: 3 :: Only\",\n    \"Programming Language :: Python :: 3.10\",\n    \"Programming Language :: Python :: 3.11\",\n    \"Programming Language :: Python :: 3.12\",\n    \"Programming Language :: Python :: 3.13\",\n    \"Programming Language :: Python :: 3.14\",\n    \"Programming Language :: Python :: Implementation :: CPython\",\n    \"Topic :: Games/Entertainment :: Simulation\",\n    \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n    \"Topic :: Scientific/Engineering :: Physics\",\n    \"Topic :: Software Development\",\n]\ndependencies = [\n    \"coloredlogs\",\n    \"jax >= 0.4.34\",\n    \"jaxlib >= 0.4.34\",\n    \"jaxlie >= 1.3.0\",\n    \"jax_dataclasses >= 1.4.0\",\n    \"pptree\",\n    \"optax >= 0.2.3\",\n    \"qpax\",\n    \"rod >= 0.4.1\",\n    \"typing_extensions ; python_version < '3.12'\",\n    
\"trimesh\",\n]\n\n[project.optional-dependencies]\ntesting = [\n    \"chex >= 0.1.91\",\n    \"idyntree >= 12.2.1\",\n    \"pytest >=6.0\",\n    \"pytest-benchmark\",\n    \"pytest-icdiff\",\n    \"pytest-xdist\",\n    \"robot-descriptions >= 1.16.0\",\n    \"icub-models\",\n]\nviz = [\n    \"lxml\",\n    \"mediapy\",\n    \"mujoco >= 3.0.0\",\n    \"scipy >= 1.14.0\",\n]\nall = [\n    \"jaxsim[testing,viz]\",\n]\n\n[project.readme]\nfile = \"README.md\"\ncontent-type = \"text/markdown\"\n\n[project.urls]\nChangelog = \"https://github.com/gbionics/jaxsim/releases\"\nDocumentation = \"https://jaxsim.readthedocs.io\"\nSource = \"https://github.com/gbionics/jaxsim\"\nTracker = \"https://github.com/gbionics/jaxsim/issues\"\n\n# ===========\n# Build tools\n# ===========\n\n[build-system]\nbuild-backend = \"hatchling.build\"\nrequires = [\n    \"hatchling\",\n    \"hatch-vcs\",\n]\n\n[tool.hatch.version]\nsource = \"vcs\"\nraw-options = { local_scheme = \"dirty-tag\" }\n\n[tool.hatch.build.targets.wheel]\npackages = [\"src/jaxsim\"]\n\n[tool.hatch.build.hooks.vcs]\nversion-file = \"src/jaxsim/_version.py\"\n\n# =================\n# Style and testing\n# =================\n\n[tool.black]\nline-length = 88\n\n[tool.isort]\nmulti_line_output = 3\nprofile = \"black\"\n\n[tool.pytest.ini_options]\naddopts = \"-rsxX -v --strict-markers --benchmark-skip --benchmark-warmup=ON\"\nminversion = \"6.0\"\ntestpaths = [\n    \"tests\",\n]\n\n# ==================\n# Ruff configuration\n# ==================\n\n[tool.ruff]\nexclude = [\n    \".git\",\n    \".pixi\",\n    \".pytest_cache\",\n    \".ruff_cache\",\n    \".idea\",\n    \".vscode\",\n    \".devcontainer\",\n    \"__pycache__\",\n]\npreview = true\n\n[tool.ruff.lint]\n# https://docs.astral.sh/ruff/rules/\nselect = [\n    \"B\",\n    \"D\",\n    \"E\",\n    \"F\",\n    \"I\",\n    \"W\",\n    \"RUF\",\n    \"UP\",\n    \"YTT\",\n]\n\nignore = [\n    \"B008\", # Function call in default argument\n    \"B024\", # Abstract base 
class without abstract methods\n    \"D100\", # Missing docstring in public module\n    \"D104\", # Missing docstring in public package\n    \"D105\", # Missing docstring in magic method\n    \"D200\", # One-line docstring should fit on one line with quotes\n    \"D202\", # No blank lines allowed after function docstring\n    \"D203\", # Incorrect blank line before class\n    \"D205\", # 1 blank line required between summary line and description\n    \"D212\", # Multi-line docstring summary should start at the first line\n    \"D411\", # Missing blank line before section\n    \"D413\", # Missing blank line after last section\n    \"E402\", # Module level import not at top of file\n    \"E501\", # Line too long\n    \"E731\", # Do not assign a `lambda` expression, use a `def`\n    \"E741\", # Ambiguous variable name\n    \"I001\", # Import block is unsorted or unformatted\n    \"RUF003\", # Ambiguous unicode character in comment\n]\n\n[tool.ruff.lint.per-file-ignores]\n# Ignore `E402` (import violations) in the tests, docs, and tools folders\n\"**/{tests,docs,tools}/*\" = [\"E402\"]\n\"**/{tests,examples}/*\" = [\"B007\", \"D100\", \"D102\", \"D103\"]\n\"__init__.py\" = [\"F401\", \"RUF067\"]\n\"docs/conf.py\" = [\"F401\"]\n\"src/jaxsim/exceptions.py\" = [\"D401\"]\n\"src/jaxsim/logging.py\" = [\"D101\", \"D103\"]\n\n# ==================\n# Pixi configuration\n# ==================\n\n[tool.pixi.workspace]\nchannels = [\"conda-forge\"]\nplatforms = [\"linux-64\", \"linux-aarch64\", \"osx-arm64\", \"osx-64\"]\nrequires-pixi = \">=0.39.0\"\npreview = [\"pixi-build\"]\n\n[tool.pixi.environments]\n# We resolve only two groups: default (CPU) and gpu.\n# Then, multiple environments can be created from these groups.\ndefault = { features = [\"test\", \"examples\"] }\ngpu = { features = [\"test\", \"examples\", \"gpu\"] }\n\n# ---------------\n# feature.default\n# ---------------\n\n# Dependencies from conda-forge.\n[tool.pixi.dependencies]\n#\n# Matching `project.dependencies`.\n#\ncoloredlogs 
= \"*\"\njax = \"*\"\njaxlib = \"*\"\njaxlie = \"*\"\njax-dataclasses = \"*\"\npptree = \"*\"\noptax = \"*\"\nqpax = \"*\"\nrod = \">=0.4.1\"\ntrimesh = \"*\"\ntyping_extensions = \"*\"\n#\n# Optional dependencies.\n#\nlxml = \"*\"\nmediapy = \"*\"\nmujoco = \"*\"\nscipy = \"*\"\n#\n# Additional dependencies.\n#\npip = \"*\"\nhatchling = \"*\"\nhatch-vcs = \"*\"\n\n# Dependencies from PyPI.\n[tool.pixi.pypi-dependencies]\njaxsim = { path = \"./\", editable = true }\n\n[tool.pixi.pypi-options]\nno-build-isolation = [\"jaxsim\"]\n\n# ------------\n# feature.test\n# ------------\n\n[tool.pixi.feature.test.tasks]\npipcheck = \"pip check\"\nbenchmark = { cmd = \"pytest --benchmark-only --benchmark-warmup=ON\", depends-on = [\"pipcheck\"] }\ntest = { cmd = \"pytest\", depends-on = [\"pipcheck\"] }\n\n[tool.pixi.feature.test.dependencies]\nblack-jupyter = \"*\"\nchex = \">=0.1.91\"\nidyntree = \"*\"\nisort = \"*\"\npre-commit = \"*\"\npytest = \"*\"\npytest-benchmark = \"*\"\npytest-icdiff = \"*\"\npytest-xdist = \"*\"\nrobot_descriptions = \">=1.16.0\"\n\n# ----------------\n# feature.examples\n# ----------------\n\n[tool.pixi.feature.examples.tasks]\nexamples = { cmd = \"jupyter notebook ./examples\" }\n\n[tool.pixi.feature.examples.dependencies]\nnotebook = \"*\"\nrobot_descriptions = \">=1.16.0\"\n\n# -----------\n# feature.gpu\n# -----------\n\n[tool.pixi.feature.gpu]\nplatforms = [\"linux-64\"]\nsystem-requirements = { cuda = \"13\" }\n\n[tool.pixi.feature.gpu.dependencies]\njaxlib = { version = \"*\", build = \"*cuda*\" }\n\n[tool.pixi.feature.gpu.tasks]\ntest-gpu = { cmd = \"pytest --gpu-only\", depends-on = [\"pipcheck\"] }\n"
  },
  {
    "path": "src/jaxsim/__init__.py",
    "content": "from . import logging\nfrom ._version import __version__\n\n\n# Follow upstream development in https://github.com/google/jax/pull/13304\ndef _jnp_options() -> None:\n    import os\n\n    import jax\n\n    # Check if running on TPU.\n    is_tpu = jax.devices()[0].platform == \"tpu\"\n\n    # Check if running on Metal.\n    is_metal = jax.devices()[0].platform == \"METAL\"\n\n    # Enable by default 64-bit precision to get accurate physics.\n    # Users can enforce 32-bit precision by setting the following variable to 0.\n    use_x64 = os.environ.get(\"JAX_ENABLE_X64\", \"1\") != \"0\"\n\n    # Notify the user if unsupported 64-bit precision was enforced on TPU.\n    if (is_tpu or is_metal) and use_x64:\n        msg = f\"64-bit precision is not allowed on {jax.devices()[0].platform.upper}. Enforcing 32bit precision.\"\n        logging.warning(msg)\n        use_x64 = False\n\n        if is_metal:\n            logging.warning(\n                \"JAX Metal backend is experimental. Some functionalities may not be available.\"\n            )\n\n    # Enable 64-bit precision in JAX.\n    if use_x64:\n        logging.info(\"Enabling JAX to use 64-bit precision\")\n        jax.config.update(\"jax_enable_x64\", True)\n\n    # Warn about experimental usage of 32-bit precision.\n    else:\n        logging.warning(\n            \"Using 32-bit precision in JaxSim is still experimental, please avoid to use variable step integrators.\"\n        )\n\n\ndef _np_options() -> None:\n    import numpy as np\n\n    np.set_printoptions(precision=5, suppress=True, linewidth=150, threshold=10_000)\n\n\ndef _is_editable() -> bool:\n\n    import importlib.util\n    import pathlib\n    import site\n\n    # Get the ModuleSpec of jaxsim.\n    jaxsim_spec = importlib.util.find_spec(name=\"jaxsim\")\n\n    # This can be None. 
If it's None, assume non-editable installation.\n    if jaxsim_spec.origin is None:\n        return False\n\n    # Get the folder containing the jaxsim package.\n    jaxsim_package_dir = str(pathlib.Path(jaxsim_spec.origin).parent.parent)\n\n    # The installation is editable if the package dir is not in any {site|dist}-packages.\n    return jaxsim_package_dir not in site.getsitepackages()\n\n\ndef _get_default_logging_level() -> logging.LoggingLevel:\n    \"\"\"\n    Get the default logging level.\n\n    Returns:\n        The logging level to set.\n    \"\"\"\n\n    import os\n    import sys\n\n    # Allow to override the default logging level with an environment variable.\n    if overriden_logging_level := os.environ.get(\"JAXSIM_LOGGING_LEVEL\"):\n        try:\n            return logging.LoggingLevel[overriden_logging_level.upper()]\n\n        except KeyError as exc:\n            msg = \"Invalid logging level defined in JAXSIM_LOGGING_LEVEL\"\n            raise RuntimeError(msg) from exc\n\n    # If running under a debugger, set the logging level to DEBUG.\n    if getattr(sys, \"gettrace\", lambda: None)():\n        return logging.LoggingLevel.DEBUG\n\n    # If not running under a debugger, set the logging level to INFO or WARNING.\n    # INFO for editable installations, WARNING for non-editable installations.\n    # This is to avoid too verbose logging in non-editable installations.\n    return (\n        logging.LoggingLevel.INFO\n        if _is_editable()  # noqa: F821\n        else logging.LoggingLevel.WARNING\n    )\n\n\n# Configure the logger with the default logging level.\nlogging.configure(level=_get_default_logging_level())\n\n\n# Configure JAX.\n_jnp_options()\n\n# Initialize the numpy print options.\n_np_options()\n\ndel _jnp_options\ndel _np_options\ndel _get_default_logging_level\ndel _is_editable\n\nfrom . import terrain  # isort:skip\nfrom . import api, logging, math, rbda\nfrom .api.common import VelRepr\n"
  },
  {
    "path": "src/jaxsim/api/__init__.py",
    "content": "from . import common  # isort:skip\nfrom . import model, data  # isort:skip\nfrom . import (\n    actuation_model,\n    com,\n    contact,\n    frame,\n    integrators,\n    joint,\n    kin_dyn_parameters,\n    link,\n    ode,\n    references,\n)\n"
  },
  {
    "path": "src/jaxsim/api/actuation_model.py",
    "content": "import jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\n\n\ndef compute_resultant_torques(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    joint_force_references: jtp.Vector | None = None,\n) -> jtp.Vector:\n    \"\"\"\n    Compute the resultant torques acting on the joints.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        joint_force_references: The joint force references to apply.\n\n    Returns:\n        The resultant torques acting on the joints.\n    \"\"\"\n\n    # Build joint torques if not provided.\n    τ_references = (\n        jnp.atleast_1d(joint_force_references.squeeze())\n        if joint_force_references is not None\n        else jnp.zeros_like(data.joint_positions)\n    ).astype(float)\n\n    # ====================\n    # Enforce joint limits\n    # ====================\n\n    τ_position_limit = jnp.zeros_like(τ_references).astype(float)\n\n    if model.dofs() > 0:\n\n        # Stiffness and damper parameters for the joint position limits.\n        k_j = jnp.array(\n            model.kin_dyn_parameters.joint_parameters.position_limit_spring\n        ).astype(float)\n        d_j = jnp.array(\n            model.kin_dyn_parameters.joint_parameters.position_limit_damper\n        ).astype(float)\n\n        # Compute the joint position limit violations.\n        lower_violation = jnp.clip(\n            data.joint_positions\n            - model.kin_dyn_parameters.joint_parameters.position_limits_min,\n            max=0.0,\n        )\n\n        upper_violation = jnp.clip(\n            data.joint_positions\n            - model.kin_dyn_parameters.joint_parameters.position_limits_max,\n            min=0.0,\n        )\n\n        # Compute the joint position limit torque.\n        τ_position_limit -= jnp.diag(k_j) @ (lower_violation + upper_violation)\n\n        τ_position_limit -= (\n            jnp.positive(τ_position_limit) 
* jnp.diag(d_j) @ data.joint_velocities\n        )\n\n    # ====================\n    # Joint friction model\n    # ====================\n\n    τ_friction = jnp.zeros_like(τ_references).astype(float)\n\n    # Apply joint friction only if enabled in the actuation parameters.\n    if model.dofs() > 0 and model.actuation_params.enable_friction:\n\n        # Static and viscous joint friction parameters\n        kc = jnp.array(\n            model.kin_dyn_parameters.joint_parameters.friction_static\n        ).astype(float)\n        kv = jnp.array(\n            model.kin_dyn_parameters.joint_parameters.friction_viscous\n        ).astype(float)\n\n        # Compute the joint friction torque.\n        τ_friction = -(\n            jnp.diag(kc) @ jnp.sign(data.joint_velocities)\n            + jnp.diag(kv) @ data.joint_velocities\n        )\n\n    # ===============================\n    # Compute the total joint forces.\n    # ===============================\n\n    τ_total = τ_references + τ_friction + τ_position_limit\n    τ_lim = tn_curve_fn(model=model, data=data)\n    τ_total = jnp.clip(τ_total, -τ_lim, τ_lim)\n    return τ_total\n\n\ndef tn_curve_fn(\n    model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n) -> jtp.Vector:\n    \"\"\"\n    Compute the torque limits using the tn curve.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The torque limits.\n    \"\"\"\n\n    τ_max = model.actuation_params.torque_max  # Max torque (Nm)\n    ω_th = model.actuation_params.omega_th  # Threshold speed (rad/s)\n    ω_max = model.actuation_params.omega_max  # Max speed for torque drop-off (rad/s)\n    abs_vel = jnp.abs(data.joint_velocities)\n    τ_lim = jnp.where(\n        abs_vel <= ω_th,\n        τ_max,\n        jnp.where(\n            abs_vel <= ω_max, τ_max * (1 - (abs_vel - ω_th) / (ω_max - ω_th)), 0.0\n        ),\n    )\n    return τ_lim\n"
  },
  {
    "path": "src/jaxsim/api/com.py",
    "content": "import jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.math\nimport jaxsim.typing as jtp\n\nfrom .common import VelRepr\n\n\n@jax.jit\n@js.common.named_scope\ndef com_position(\n    model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n) -> jtp.Vector:\n    \"\"\"\n    Compute the position of the center of mass of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The position of the center of mass of the model w.r.t. the world frame.\n    \"\"\"\n\n    m = js.model.total_mass(model=model)\n\n    W_H_L = data._link_transforms\n    W_H_B = data._base_transform\n    B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B)\n\n    def B_p̃_LCoM(i) -> jtp.Vector:\n        m = js.link.mass(model=model, link_index=i)\n        L_p_LCoM = js.link.com_position(\n            model=model, data=data, link_index=i, in_link_frame=True\n        )\n        return m * B_H_W @ W_H_L[i] @ jnp.hstack([L_p_LCoM, 1])\n\n    com_links = jax.vmap(B_p̃_LCoM)(jnp.arange(model.number_of_links()))\n\n    B_p̃_CoM = (1 / m) * com_links.sum(axis=0)\n    B_p̃_CoM = B_p̃_CoM.at[3].set(1)\n\n    return (W_H_B @ B_p̃_CoM)[0:3].astype(float)\n\n\n@jax.jit\n@js.common.named_scope\ndef com_linear_velocity(\n    model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n) -> jtp.Vector:\n    r\"\"\"\n    Compute the linear velocity of the center of mass of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The linear velocity of the center of mass of the model in the\n        active representation.\n\n    Note:\n        The linear velocity of the center of mass  is expressed in the mixed frame\n        :math:`G = ({}^W \\mathbf{p}_{\\text{CoM}}, [C])`, where :math:`[C] = [W]` if the\n        active velocity representation is either inertial-fixed or mixed,\n        and :math:`[C] = [B]` if 
the active velocity representation is body-fixed.\n    \"\"\"\n\n    # Extract the linear component of the 6D average centroidal velocity.\n    # This is expressed in G[B] in body-fixed representation, and in G[W] in\n    # inertial-fixed or mixed representation.\n    G_vl_WG = average_centroidal_velocity(model=model, data=data)[0:3]\n\n    return G_vl_WG\n\n\n@jax.jit\n@js.common.named_scope\ndef centroidal_momentum(\n    model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n) -> jtp.Vector:\n    r\"\"\"\n    Compute the centroidal momentum of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The centroidal momentum of the model.\n\n    Note:\n        The centroidal momentum is expressed in the mixed frame\n        :math:`({}^W \\mathbf{p}_{\\text{CoM}}, [C])`, where :math:`C = W` if the\n        active velocity representation is either inertial-fixed or mixed,\n        and :math:`C = B` if the active velocity representation is body-fixed.\n    \"\"\"\n\n    ν = data.generalized_velocity\n    G_J = centroidal_momentum_jacobian(model=model, data=data)\n\n    return G_J @ ν\n\n\n@jax.jit\n@js.common.named_scope\ndef centroidal_momentum_jacobian(\n    model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n) -> jtp.Matrix:\n    r\"\"\"\n    Compute the Jacobian of the centroidal momentum of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The Jacobian of the centroidal momentum of the model.\n\n    Note:\n        The frame corresponding to the output representation of this Jacobian is either\n        :math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,\n        or :math:`G[B]`, if the active velocity representation is body-fixed.\n\n    Note:\n        This Jacobian is also known in the literature as Centroidal Momentum Matrix.\n    \"\"\"\n\n    # Compute the 
Jacobian of the total momentum with body-fixed output representation.\n    # We convert the output representation either to G[W] or G[B] below.\n    B_Jh = js.model.total_momentum_jacobian(\n        model=model, data=data, output_vel_repr=VelRepr.Body\n    )\n\n    W_H_B = data._base_transform\n    B_H_W = jaxsim.math.Transform.inverse(W_H_B)\n\n    W_p_CoM = com_position(model=model, data=data)\n\n    match data.velocity_representation:\n        case VelRepr.Inertial | VelRepr.Mixed:\n            W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)  # noqa: F841\n        case VelRepr.Body:\n            W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM)  # noqa: F841\n        case _:\n            raise ValueError(data.velocity_representation)\n\n    # Compute the transform for 6D forces.\n    G_Xf_B = jaxsim.math.Adjoint.from_transform(transform=B_H_W @ W_H_G).T\n\n    return G_Xf_B @ B_Jh\n\n\n@jax.jit\n@js.common.named_scope\ndef locked_centroidal_spatial_inertia(\n    model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n):\n    \"\"\"\n    Compute the locked centroidal spatial inertia of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The locked centroidal spatial inertia of the model.\n    \"\"\"\n\n    with data.switch_velocity_representation(VelRepr.Body):\n        B_Mbb_B = js.model.locked_spatial_inertia(model=model, data=data)\n\n    W_H_B = data._base_transform\n    W_p_CoM = com_position(model=model, data=data)\n\n    match data.velocity_representation:\n        case VelRepr.Inertial | VelRepr.Mixed:\n            W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)  # noqa: F841\n        case VelRepr.Body:\n            W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM)  # noqa: F841\n        case _:\n            raise ValueError(data.velocity_representation)\n\n    B_H_G = jaxsim.math.Transform.inverse(W_H_B) @ W_H_G\n\n    B_Xv_G = 
jaxsim.math.Adjoint.from_transform(transform=B_H_G)\n    G_Xf_B = B_Xv_G.transpose()\n\n    return G_Xf_B @ B_Mbb_B @ B_Xv_G\n\n\n@jax.jit\n@js.common.named_scope\ndef average_centroidal_velocity(\n    model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n) -> jtp.Vector:\n    r\"\"\"\n    Compute the average centroidal velocity of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The average centroidal velocity of the model.\n\n    Note:\n        The average velocity is expressed in the mixed frame\n        :math:`G = ({}^W \\mathbf{p}_{\\text{CoM}}, [C])`, where :math:`[C] = [W]` if the\n        active velocity representation is either inertial-fixed or mixed,\n        and :math:`[C] = [B]` if the active velocity representation is body-fixed.\n    \"\"\"\n\n    ν = data.generalized_velocity\n    G_J = average_centroidal_velocity_jacobian(model=model, data=data)\n\n    return G_J @ ν\n\n\n@jax.jit\n@js.common.named_scope\ndef average_centroidal_velocity_jacobian(\n    model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n) -> jtp.Matrix:\n    r\"\"\"\n    Compute the Jacobian of the average centroidal velocity of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The Jacobian of the average centroidal velocity of the model.\n\n    Note:\n        The frame corresponding to the output representation of this Jacobian is either\n        :math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,\n        or :math:`G[B]`, if the active velocity representation is body-fixed.\n    \"\"\"\n\n    G_J = centroidal_momentum_jacobian(model=model, data=data)\n    G_Mbb = locked_centroidal_spatial_inertia(model=model, data=data)\n\n    return jnp.linalg.inv(G_Mbb) @ G_J\n\n\n@jax.jit\n@js.common.named_scope\ndef bias_acceleration(\n    model: js.model.JaxSimModel, data: 
js.data.JaxSimModelData\n) -> jtp.Vector:\n    r\"\"\"\n    Compute the bias linear acceleration of the center of mass.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The bias linear acceleration of the center of mass in the active representation.\n\n    Note:\n        The bias acceleration is expressed in the mixed frame\n        :math:`G = ({}^W \\mathbf{p}_{\\text{CoM}}, [C])`, where :math:`[C] = [W]` if the\n        active velocity representation is either inertial-fixed or mixed,\n        and :math:`[C] = [B]` if the active velocity representation is body-fixed.\n    \"\"\"\n\n    # Compute the pose of all links with forward kinematics.\n    W_H_L = data._link_transforms\n\n    # Compute the bias acceleration of all links by zeroing the generalized velocity\n    # in the active representation.\n    v̇_bias_WL = js.model.link_bias_accelerations(model=model, data=data)\n\n    def other_representation_to_body(\n        C_v̇_WL: jtp.Vector, C_v_WC: jtp.Vector, L_H_C: jtp.Matrix, L_v_LC: jtp.Vector\n    ) -> jtp.Vector:\n        \"\"\"\n        Convert the body-fixed representation of the link bias acceleration\n        C_v̇_WL expressed in a generic frame C to the body-fixed representation L_v̇_WL.\n        \"\"\"\n\n        L_X_C = jaxsim.math.Adjoint.from_transform(transform=L_H_C)\n        C_X_L = jaxsim.math.Adjoint.inverse(L_X_C)\n\n        L_v̇_WL = L_X_C @ (C_v̇_WL + jaxsim.math.Cross.vx(C_X_L @ L_v_LC) @ C_v_WC)\n        return L_v̇_WL\n\n    # We need here to get the body-fixed bias acceleration of the links.\n    # Since it's computed in the active representation, we need to convert it to body.\n    match data.velocity_representation:\n\n        case VelRepr.Body:\n            L_a_bias_WL = v̇_bias_WL\n\n        case VelRepr.Inertial:\n\n            C_v̇_WL = W_v̇_bias_WL = v̇_bias_WL  # noqa: F841\n            C_v_WC = W_v_WW = jnp.zeros(6)  # noqa: F841\n\n            L_H_C 
= L_H_W = jax.vmap(jaxsim.math.Transform.inverse)(W_H_L)  # noqa: F841\n\n            L_v_LC = L_v_LW = jax.vmap(  # noqa: F841\n                lambda i: -js.link.velocity(\n                    model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body\n                )\n            )(jnp.arange(model.number_of_links()))\n\n            L_a_bias_WL = jax.vmap(\n                lambda i: other_representation_to_body(\n                    C_v̇_WL=C_v̇_WL[i],\n                    C_v_WC=C_v_WC,\n                    L_H_C=L_H_C[i],\n                    L_v_LC=L_v_LC[i],\n                )\n            )(jnp.arange(model.number_of_links()))\n\n        case VelRepr.Mixed:\n\n            C_v̇_WL = LW_v̇_bias_WL = v̇_bias_WL  # noqa: F841\n\n            C_v_WC = LW_v_W_LW = jax.vmap(  # noqa: F841\n                lambda i: js.link.velocity(\n                    model=model, data=data, link_index=i, output_vel_repr=VelRepr.Mixed\n                )\n                .at[3:6]\n                .set(jnp.zeros(3))\n            )(jnp.arange(model.number_of_links()))\n\n            L_H_C = L_H_LW = jax.vmap(  # noqa: F841\n                lambda W_H_L: jaxsim.math.Transform.inverse(\n                    W_H_L.at[0:3, 3].set(jnp.zeros(3))\n                )\n            )(W_H_L)\n\n            L_v_LC = L_v_L_LW = jax.vmap(  # noqa: F841\n                lambda i: -js.link.velocity(\n                    model=model, data=data, link_index=i, output_vel_repr=VelRepr.Body\n                )\n                .at[0:3]\n                .set(jnp.zeros(3))\n            )(jnp.arange(model.number_of_links()))\n\n            L_a_bias_WL = jax.vmap(\n                lambda i: other_representation_to_body(\n                    C_v̇_WL=C_v̇_WL[i],\n                    C_v_WC=C_v_WC[i],\n                    L_H_C=L_H_C[i],\n                    L_v_LC=L_v_LC[i],\n                )\n            )(jnp.arange(model.number_of_links()))\n\n        case _:\n            raise 
ValueError(data.velocity_representation)\n\n    # Compute the bias of the 6D momentum derivative.\n    def bias_momentum_derivative_term(\n        link_index: jtp.Int, L_a_bias_WL: jtp.Vector\n    ) -> jtp.Vector:\n\n        # Get the body-fixed 6D inertia matrix.\n        L_M_L = js.link.spatial_inertia(model=model, link_index=link_index)\n\n        # Compute the body-fixed 6D velocity.\n        L_v_WL = js.link.velocity(\n            model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body\n        )\n\n        # Compute the world-to-link transformations for 6D forces.\n        W_Xf_L = jaxsim.math.Adjoint.from_transform(\n            transform=W_H_L[link_index], inverse=True\n        ).T\n\n        # Compute the contribution of the link to the bias acceleration of the CoM.\n        W_ḣ_bias_link_contribution = W_Xf_L @ (\n            L_M_L @ L_a_bias_WL + jaxsim.math.Cross.vx_star(L_v_WL) @ L_M_L @ L_v_WL\n        )\n\n        return W_ḣ_bias_link_contribution\n\n    # Sum the contributions of all links to the bias acceleration of the CoM.\n    W_ḣ_bias = jax.vmap(bias_momentum_derivative_term)(\n        jnp.arange(model.number_of_links()), L_a_bias_WL\n    ).sum(axis=0)\n\n    # Compute the total mass of the model.\n    m = js.model.total_mass(model=model)\n\n    # Compute the position of the CoM.\n    W_p_CoM = com_position(model=model, data=data)\n\n    match data.velocity_representation:\n\n        # G := G[W] = (W_p_CoM, [W])\n        case VelRepr.Inertial | VelRepr.Mixed:\n\n            W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)\n            GW_Xf_W = jaxsim.math.Adjoint.from_transform(W_H_GW).T\n\n            GW_ḣ_bias = GW_Xf_W @ W_ḣ_bias\n            GW_v̇l_com_bias = GW_ḣ_bias[0:3] / m\n\n            return GW_v̇l_com_bias\n\n        # G := G[B] = (W_p_CoM, [B])\n        case VelRepr.Body:\n\n            GB_Xf_W = jaxsim.math.Adjoint.from_transform(\n                transform=data._base_transform.at[0:3].set(W_p_CoM)\n         
   ).T\n\n            GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias\n            GB_v̇l_com_bias = GB_ḣ_bias[0:3] / m\n\n            return GB_v̇l_com_bias\n\n        case _:\n            raise ValueError(data.velocity_representation)\n"
  },
  {
    "path": "src/jaxsim/api/common.py",
    "content": "import abc\nimport contextlib\nimport dataclasses\nimport enum\nimport functools\nfrom collections.abc import Callable, Iterator\nfrom typing import ParamSpec, TypeVar\n\nimport jax\nimport jax.numpy as jnp\nimport jax_dataclasses\nfrom jax_dataclasses import Static\n\nimport jaxsim.typing as jtp\nfrom jaxsim.math import Adjoint\nfrom jaxsim.utils import JaxsimDataclass, Mutability\n\ntry:\n    from typing import Self\nexcept ImportError:\n    from typing_extensions import Self\n\n\n_P = ParamSpec(\"_P\")\n_R = TypeVar(\"_R\")\n\n\ndef named_scope(fn, name: str | None = None) -> Callable[_P, _R]:\n    \"\"\"Apply a JAX named scope to a function for improved profiling and clarity.\"\"\"\n\n    @functools.wraps(fn)\n    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:\n        with jax.named_scope(name or fn.__name__):\n            return fn(*args, **kwargs)\n\n    return wrapper\n\n\n@enum.unique\nclass VelRepr(enum.IntEnum):\n    \"\"\"\n    Enumeration of all supported 6D velocity representations.\n    \"\"\"\n\n    Body = enum.auto()\n    Mixed = enum.auto()\n    Inertial = enum.auto()\n\n\n@jax_dataclasses.pytree_dataclass\nclass ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):\n    \"\"\"\n    Base class for model data structures with velocity representation.\n    \"\"\"\n\n    velocity_representation: Static[VelRepr] = dataclasses.field(\n        default=VelRepr.Inertial, kw_only=True\n    )\n\n    @contextlib.contextmanager\n    def switch_velocity_representation(\n        self, velocity_representation: VelRepr\n    ) -> Iterator[Self]:\n        \"\"\"\n        Context manager to temporarily switch the velocity representation.\n\n        Args:\n            velocity_representation: The new velocity representation.\n\n        Yields:\n            The same object with the new velocity representation.\n        \"\"\"\n\n        original_representation = self.velocity_representation\n\n        try:\n\n            # First, we 
replace the velocity representation.\n            with self.mutable_context(\n                mutability=Mutability.MUTABLE_NO_VALIDATION,\n                restore_after_exception=True,\n            ):\n                self.velocity_representation = velocity_representation\n\n            # Then, we yield the data with changed representation.\n            # We run this in a mutable context with restoration so that any exception\n            # occurring, we restore the original object in case it was modified.\n            with self.mutable_context(\n                mutability=self.mutability(), restore_after_exception=True\n            ):\n                yield self\n\n        finally:\n            with self.mutable_context(\n                mutability=Mutability.MUTABLE_NO_VALIDATION,\n                restore_after_exception=True,\n            ):\n                self.velocity_representation = original_representation\n\n    @staticmethod\n    @functools.partial(jax.jit, static_argnames=[\"other_representation\", \"is_force\"])\n    def inertial_to_other_representation(\n        array: jtp.Array,\n        other_representation: VelRepr,\n        transform: jtp.Matrix,\n        *,\n        is_force: bool,\n    ) -> jtp.Array:\n        r\"\"\"\n        Convert a 6D quantity from inertial-fixed to another representation.\n\n        Args:\n            array: The 6D quantity to convert.\n            other_representation: The representation to convert to.\n            transform:\n                The :math:`W \\mathbf{H}_O` transform, where :math:`O` is the\n                reference frame of the other representation.\n            is_force: Whether the quantity is a 6D force or a 6D velocity.\n\n        Returns:\n            The 6D quantity in the other representation.\n        \"\"\"\n\n        W_array = array\n        W_H_O = transform\n\n        match other_representation:\n\n            case VelRepr.Inertial:\n                return W_array\n\n            case 
VelRepr.Body:\n\n                if not is_force:\n                    O_Xv_W = Adjoint.from_transform(transform=W_H_O, inverse=True)\n                    O_array = jnp.einsum(\"...ij,...j->...i\", O_Xv_W, W_array)\n\n                else:\n                    O_Xf_W = Adjoint.from_transform(transform=W_H_O).swapaxes(-1, -2)\n                    O_array = jnp.einsum(\"...ij,...j->...i\", O_Xf_W, W_array)\n\n                return O_array\n\n            case VelRepr.Mixed:\n                W_H_OW = W_H_O.at[..., 0:3, 0:3].set(jnp.eye(3))\n\n                if not is_force:\n                    OW_Xv_W = Adjoint.from_transform(transform=W_H_OW, inverse=True)\n                    OW_array = jnp.einsum(\"...ij,...j->...i\", OW_Xv_W, W_array)\n\n                else:\n                    OW_Xf_W = Adjoint.from_transform(transform=W_H_OW).swapaxes(-1, -2)\n                    OW_array = jnp.einsum(\"...ij,...j->...i\", OW_Xf_W, W_array)\n\n                return OW_array\n\n            case _:\n                raise ValueError(other_representation)\n\n    @staticmethod\n    @functools.partial(jax.jit, static_argnames=[\"other_representation\", \"is_force\"])\n    def other_representation_to_inertial(\n        array: jtp.Array,\n        other_representation: VelRepr,\n        transform: jtp.Matrix,\n        *,\n        is_force: bool,\n    ) -> jtp.Array:\n        r\"\"\"\n        Convert a 6D quantity from another representation to inertial-fixed.\n\n        Args:\n            array: The 6D quantity to convert.\n            other_representation: The representation to convert from.\n            transform:\n                The :math:`W \\mathbf{H}_O` transform, where :math:`O` is the\n                reference frame of the other representation.\n            is_force: Whether the quantity is a 6D force or a 6D velocity.\n\n        Returns:\n            The 6D quantity in the inertial-fixed representation.\n        \"\"\"\n\n        O_array = array\n        W_H_O = 
transform\n\n        match other_representation:\n            case VelRepr.Inertial:\n                return O_array\n\n            case VelRepr.Body:\n\n                if not is_force:\n                    W_Xv_O = Adjoint.from_transform(W_H_O)\n                    W_array = jnp.einsum(\"...ij,...j->...i\", W_Xv_O, O_array)\n\n                else:\n                    W_Xf_O = Adjoint.from_transform(\n                        transform=W_H_O, inverse=True\n                    ).swapaxes(-1, -2)\n                    W_array = jnp.einsum(\"...ij,...j->...i\", W_Xf_O, O_array)\n\n                return W_array\n\n            case VelRepr.Mixed:\n\n                W_H_OW = W_H_O.at[..., 0:3, 0:3].set(jnp.eye(3))\n\n                if not is_force:\n                    W_Xv_BW = Adjoint.from_transform(W_H_OW)\n                    W_array = jnp.einsum(\"...ij,...j->...i\", W_Xv_BW, O_array)\n\n                else:\n                    W_Xf_BW = Adjoint.from_transform(\n                        transform=W_H_OW, inverse=True\n                    ).swapaxes(-1, -2)\n                    W_array = jnp.einsum(\"...ij,...j->...i\", W_Xf_BW, O_array)\n\n                return W_array\n\n            case _:\n                raise ValueError(other_representation)\n"
  },
  {
    "path": "src/jaxsim/api/contact.py",
    "content": "from __future__ import annotations\n\nimport functools\n\nimport jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.exceptions\nimport jaxsim.typing as jtp\nfrom jaxsim import logging\nfrom jaxsim.math import Adjoint, Cross, Transform\nfrom jaxsim.rbda.contacts import SoftContacts\n\nfrom .common import VelRepr\n\n\n@jax.jit\n@js.common.named_scope\ndef collidable_point_kinematics(\n    model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n) -> tuple[jtp.Matrix, jtp.Matrix]:\n    \"\"\"\n    Compute the position and 3D velocity of the collidable points in the world frame.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The position and velocity of the collidable points in the world frame.\n\n    Note:\n        The collidable point velocity is the plain coordinate derivative of the position.\n        If we attach a frame C = (p_C, [C]) to the collidable point, it corresponds to\n        the linear component of the mixed 6D frame velocity.\n    \"\"\"\n\n    W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(\n        model=model,\n        link_transforms=data._link_transforms,\n        link_velocities=data._link_velocities,\n    )\n\n    return W_p_Ci, W_ṗ_Ci\n\n\n@jax.jit\n@js.common.named_scope\ndef collidable_point_positions(\n    model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the position of the collidable points in the world frame.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The position of the collidable points in the world frame.\n    \"\"\"\n\n    W_p_Ci, _ = collidable_point_kinematics(model=model, data=data)\n\n    return W_p_Ci\n\n\n@jax.jit\n@js.common.named_scope\ndef collidable_point_velocities(\n    model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n) -> 
jtp.Matrix:\n    \"\"\"\n    Compute the 3D velocity of the collidable points in the world frame.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The 3D velocity of the collidable points.\n    \"\"\"\n\n    _, W_ṗ_Ci = collidable_point_kinematics(model=model, data=data)\n\n    return W_ṗ_Ci\n\n\n@functools.partial(jax.jit, static_argnames=[\"link_names\"])\n@js.common.named_scope\ndef in_contact(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    link_names: tuple[str, ...] | None = None,\n) -> jtp.Vector:\n    \"\"\"\n    Return whether the links are in contact with the terrain.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        link_names:\n            The names of the links to consider. If None, all links are considered.\n\n    Returns:\n        A boolean vector indicating whether the links are in contact with the terrain.\n    \"\"\"\n\n    if link_names is not None and set(link_names).difference(model.link_names()):\n        raise ValueError(\"One or more link names are not part of the model\")\n\n    # Get the indices of the enabled collidable points.\n    indices_of_enabled_collidable_points = (\n        model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points\n    )\n\n    parent_link_idx_of_enabled_collidable_points = jnp.array(\n        model.kin_dyn_parameters.contact_parameters.body, dtype=int\n    )[indices_of_enabled_collidable_points]\n\n    W_p_Ci = collidable_point_positions(model=model, data=data)\n\n    terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))(\n        W_p_Ci[:, 0], W_p_Ci[:, 1]\n    )\n\n    below_terrain = W_p_Ci[:, 2] <= terrain_height\n\n    link_idxs = (\n        js.link.names_to_idxs(link_names=link_names, model=model)\n        if link_names is not None\n        else jnp.arange(model.number_of_links())\n    )\n\n    
links_in_contact = jax.vmap(\n        lambda link_index: jnp.where(\n            parent_link_idx_of_enabled_collidable_points == link_index,\n            below_terrain,\n            jnp.zeros_like(below_terrain, dtype=bool),\n        ).any()\n    )(link_idxs)\n\n    return links_in_contact\n\n\ndef estimate_good_soft_contacts_parameters(\n    *args, **kwargs\n) -> jaxsim.rbda.contacts.ContactParamsTypes:\n    \"\"\"\n    Estimate good soft contact parameters. Deprecated, use `estimate_good_contact_parameters` instead.\n    \"\"\"\n\n    msg = \"This method is deprecated, please use `{}`.\"\n    logging.warning(msg.format(estimate_good_contact_parameters.__name__))\n    return estimate_good_contact_parameters(*args, **kwargs)\n\n\ndef estimate_good_contact_parameters(\n    model: js.model.JaxSimModel,\n    *,\n    standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,\n    static_friction_coefficient: jtp.FloatLike = 0.5,\n    number_of_active_collidable_points_steady_state: jtp.IntLike = 1,\n    damping_ratio: jtp.FloatLike = 1.0,\n    max_penetration: jtp.FloatLike | None = None,\n) -> jaxsim.rbda.contacts.ContactParamsTypes:\n    \"\"\"\n    Estimate good contact parameters.\n\n    Args:\n        model: The model to consider.\n        standard_gravity: The standard gravity acceleration.\n        static_friction_coefficient: The static friction coefficient.\n        number_of_active_collidable_points_steady_state:\n            The number of active collidable points in steady state.\n        damping_ratio: The damping ratio.\n        max_penetration: The maximum penetration allowed.\n\n    Returns:\n        The estimated contact parameters.\n\n    Note:\n        This is primarily a convenience function for soft-like contact models.\n        However, it also provides reasonable default parameters for the other contact models.\n\n    Note:\n        This method provides a good set of contact parameters.\n        The user is encouraged to fine-tune the 
parameters based on the\n        specific application.\n    \"\"\"\n    if max_penetration is None:\n        zero_data = js.data.JaxSimModelData.build(model=model)\n        W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]\n        if model.floating_base():\n            W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]\n            W_pz_CoM = W_pz_CoM - W_pz_C.min()\n\n        # Consider as default a 1% of the model center of mass height.\n        max_penetration = 0.01 * W_pz_CoM\n\n    nc = number_of_active_collidable_points_steady_state\n    return model.contact_model._parameters_class().build_default_from_jaxsim_model(\n        model=model,\n        standard_gravity=standard_gravity,\n        static_friction_coefficient=static_friction_coefficient,\n        max_penetration=max_penetration,\n        number_of_active_collidable_points_steady_state=nc,\n        damping_ratio=damping_ratio,\n    )\n\n\n@jax.jit\n@js.common.named_scope\ndef transforms(model: js.model.JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Array:\n    r\"\"\"\n    Return the pose of the enabled collidable points.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The stacked SE(3) matrices of all enabled collidable points.\n\n    Note:\n        Each collidable point is implicitly associated with a frame\n        :math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the\n        collidable point and :math:`[L]` is the orientation frame of the link it is\n        rigidly attached to.\n    \"\"\"\n\n    # Get the indices of the enabled collidable points.\n    indices_of_enabled_collidable_points = (\n        model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points\n    )\n\n    parent_link_idx_of_enabled_collidable_points = jnp.array(\n        model.kin_dyn_parameters.contact_parameters.body, dtype=int\n    
)[indices_of_enabled_collidable_points]\n\n    # Get the transforms of the parent link of all collidable points.\n    W_H_L = data._link_transforms[parent_link_idx_of_enabled_collidable_points]\n\n    L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[\n        indices_of_enabled_collidable_points\n    ]\n\n    # Build the link-to-point transform from the displacement between the link frame L\n    # and the implicit contact frame C.\n    L_H_C = jax.vmap(jnp.eye(4).at[0:3, 3].set)(L_p_Ci)\n\n    # Compose the world-to-link and link-to-point transforms.\n    return jax.vmap(lambda W_H_Li, L_H_Ci: W_H_Li @ L_H_Ci)(W_H_L, L_H_C)\n\n\n@functools.partial(jax.jit, static_argnames=[\"output_vel_repr\"])\n@js.common.named_scope\ndef jacobian(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    output_vel_repr: VelRepr | None = None,\n) -> jtp.Array:\n    r\"\"\"\n    Return the free-floating Jacobian of the enabled collidable points.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        output_vel_repr:\n            The output velocity representation of the free-floating jacobian.\n\n    Returns:\n        The stacked :math:`6 \\times (6+n)` free-floating jacobians of the frames associated to the\n        enabled collidable points.\n\n    Note:\n        Each collidable point is implicitly associated with a frame\n        :math:`C = ({}^W p_C, [L])`, where :math:`{}^W p_C` is the position of the\n        collidable point and :math:`[L]` is the orientation frame of the link it is\n        rigidly attached to.\n    \"\"\"\n\n    output_vel_repr = (\n        output_vel_repr if output_vel_repr is not None else data.velocity_representation\n    )\n\n    # Get the indices of the enabled collidable points.\n    indices_of_enabled_collidable_points = (\n        model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points\n    )\n\n    
parent_link_idx_of_enabled_collidable_points = jnp.array(\n        model.kin_dyn_parameters.contact_parameters.body, dtype=int\n    )[indices_of_enabled_collidable_points]\n\n    # Compute the Jacobians of all links.\n    W_J_WL = js.model.generalized_free_floating_jacobian(\n        model=model, data=data, output_vel_repr=VelRepr.Inertial\n    )\n\n    # Compute the contact Jacobian.\n    # In inertial-fixed output representation, the Jacobian of the parent link is also\n    # the Jacobian of the frame C implicitly associated with the collidable point.\n    W_J_WC = W_J_WL[parent_link_idx_of_enabled_collidable_points]\n\n    # Adjust the output representation.\n    match output_vel_repr:\n\n        case VelRepr.Inertial:\n            O_J_WC = W_J_WC\n\n        case VelRepr.Body:\n\n            W_H_C = transforms(model=model, data=data)\n\n            def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:\n                C_X_W = jaxsim.math.Adjoint.from_transform(\n                    transform=W_H_C, inverse=True\n                )\n                C_J_WC = C_X_W @ W_J_WC\n                return C_J_WC\n\n            O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC)\n\n        case VelRepr.Mixed:\n\n            W_H_C = transforms(model=model, data=data)\n\n            def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:\n\n                W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))\n\n                CW_X_W = jaxsim.math.Adjoint.from_transform(\n                    transform=W_H_CW, inverse=True\n                )\n\n                CW_J_WC = CW_X_W @ W_J_WC\n                return CW_J_WC\n\n            O_J_WC = jax.vmap(mixed_jacobian)(W_H_C, W_J_WC)\n\n        case _:\n            raise ValueError(output_vel_repr)\n\n    return O_J_WC\n\n\n@functools.partial(jax.jit, static_argnames=[\"output_vel_repr\"])\n@js.common.named_scope\ndef jacobian_derivative(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n  
  *,\n    output_vel_repr: VelRepr | None = None,\n) -> jtp.Matrix:\n    r\"\"\"\n    Compute the derivative of the free-floating jacobian of the enabled collidable points.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        output_vel_repr:\n            The output velocity representation of the free-floating jacobian derivative.\n\n    Returns:\n        The derivative of the :math:`6 \\times (6+n)` free-floating jacobian of the enabled collidable points.\n\n    Note:\n        The input representation of the free-floating jacobian derivative is the active\n        velocity representation.\n    \"\"\"\n\n    output_vel_repr = (\n        output_vel_repr if output_vel_repr is not None else data.velocity_representation\n    )\n\n    indices_of_enabled_collidable_points = (\n        model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points\n    )\n\n    # Get the index of the parent link and the position of the collidable point.\n    parent_link_idx_of_enabled_collidable_points = jnp.array(\n        model.kin_dyn_parameters.contact_parameters.body, dtype=int\n    )[indices_of_enabled_collidable_points]\n\n    L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[\n        indices_of_enabled_collidable_points\n    ]\n\n    # Get the transforms of all the parent links.\n    W_H_Li = data._link_transforms\n\n    # Get the link velocities.\n    W_v_WLi = data._link_velocities\n\n    # =====================================================\n    # Compute quantities to adjust the input representation\n    # =====================================================\n\n    def compute_T(model: js.model.JaxSimModel, X: jtp.Matrix) -> jtp.Matrix:\n        In = jnp.eye(model.dofs())\n        T = jax.scipy.linalg.block_diag(X, In)\n        return T\n\n    def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:\n        On = jnp.zeros(shape=(model.dofs(), model.dofs()))\n        Ṫ = 
jax.scipy.linalg.block_diag(Ẋ, On)\n        return Ṫ\n\n    # Compute the operator to change the representation of ν, and its\n    # time derivative.\n    match data.velocity_representation:\n        case VelRepr.Inertial:\n            W_H_W = jnp.eye(4)\n            W_X_W = Adjoint.from_transform(transform=W_H_W)\n            W_Ẋ_W = jnp.zeros((6, 6))\n\n            T = compute_T(model=model, X=W_X_W)\n            Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W)\n\n        case VelRepr.Body:\n            W_H_B = data._base_transform\n            W_X_B = Adjoint.from_transform(transform=W_H_B)\n            B_v_WB = data.base_velocity\n            B_vx_WB = Cross.vx(B_v_WB)\n            W_Ẋ_B = W_X_B @ B_vx_WB\n\n            T = compute_T(model=model, X=W_X_B)\n            Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B)\n\n        case VelRepr.Mixed:\n            W_H_B = data._base_transform\n            W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))\n            W_X_BW = Adjoint.from_transform(transform=W_H_BW)\n            BW_v_WB = data.base_velocity\n            BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))\n            BW_vx_W_BW = Cross.vx(BW_v_W_BW)\n            W_Ẋ_BW = W_X_BW @ BW_vx_W_BW\n\n            T = compute_T(model=model, X=W_X_BW)\n            Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW)\n\n        case _:\n            raise ValueError(data.velocity_representation)\n\n    # =====================================================\n    # Compute quantities to adjust the output representation\n    # =====================================================\n\n    with data.switch_velocity_representation(VelRepr.Inertial):\n        # Compute the Jacobian of the parent link in inertial representation.\n        W_J_WL_W = js.model.generalized_free_floating_jacobian(\n            model=model,\n            data=data,\n        )\n        # Compute the Jacobian derivative of the parent link in inertial representation.\n        W_J̇_WL_W = 
js.model.generalized_free_floating_jacobian_derivative(\n            model=model,\n            data=data,\n        )\n\n    def compute_O_J̇_WC_I(\n        L_p_C: jtp.Vector,\n        parent_link_idx: jtp.Int,\n        W_H_L: jtp.Matrix,\n    ) -> jtp.Matrix:\n\n        match output_vel_repr:\n            case VelRepr.Inertial:\n                O_X_W = W_X_W = Adjoint.from_transform(  # noqa: F841\n                    transform=jnp.eye(4)\n                )\n                O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6))  # noqa: F841\n\n            case VelRepr.Body:\n                L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)\n                W_H_C = W_H_L[parent_link_idx] @ L_H_C\n                O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)\n                W_v_WC = W_v_WLi[parent_link_idx]\n                W_vx_WC = Cross.vx(W_v_WC)\n                O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC  # noqa: F841\n\n            case VelRepr.Mixed:\n                L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)\n                W_H_C = W_H_L[parent_link_idx] @ L_H_C\n                W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))\n                CW_H_W = Transform.inverse(W_H_CW)\n                O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W)\n                CW_v_WC = CW_X_W @ W_v_WLi[parent_link_idx]\n                W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3])\n                W_vx_W_CW = Cross.vx(W_v_W_CW)\n                O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW  # noqa: F841\n\n            case _:\n                raise ValueError(output_vel_repr)\n\n        O_J̇_WC_I = jnp.zeros(shape=(6, 6 + model.dofs()))\n        O_J̇_WC_I += O_Ẋ_W @ W_J_WL_W[parent_link_idx] @ T\n        O_J̇_WC_I += O_X_W @ W_J̇_WL_W[parent_link_idx] @ T\n        O_J̇_WC_I += O_X_W @ W_J_WL_W[parent_link_idx] @ Ṫ\n\n        return O_J̇_WC_I\n\n    O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, None))(\n        L_p_Ci, 
parent_link_idx_of_enabled_collidable_points, W_H_Li\n    )\n\n    return O_J̇_WC\n\n\n@jax.jit\n@js.common.named_scope\ndef link_contact_forces(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    link_forces: jtp.MatrixLike | None = None,\n    joint_torques: jtp.VectorLike | None = None,\n) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]:\n    \"\"\"\n    Compute the 6D contact forces of all links of the model in inertial representation.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        link_forces:\n            The 6D external forces to apply to the links expressed in inertial representation\n        joint_torques:\n            The joint torques acting on the joints.\n\n    Returns:\n        A `(nL, 6)` array containing the stacked 6D contact forces of the links,\n        expressed in inertial representation.\n    \"\"\"\n\n    # Compute the contact forces for each collidable point with the active contact model.\n    W_f_C, aux_dict = model.contact_model.compute_contact_forces(\n        model=model,\n        data=data,\n        **(\n            dict(link_forces=link_forces, joint_force_references=joint_torques)\n            if not isinstance(model.contact_model, SoftContacts)\n            else {}\n        ),\n    )\n\n    # Compute the 6D forces applied to the links equivalent to the forces applied\n    # to the frames associated to the collidable points.\n    W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C)\n\n    return W_f_L, aux_dict\n\n\ndef link_forces_from_contact_forces(\n    model: js.model.JaxSimModel,\n    *,\n    contact_forces: jtp.MatrixLike,\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the link forces from the contact forces.\n\n    Args:\n        model: The robot model considered by the contact model.\n        contact_forces: The contact forces computed by the contact model.\n\n    Returns:\n        The 6D contact forces applied to the 
links, expressed in\n        inertial representation (the same representation as the input contact forces).\n    \"\"\"\n\n    # Get the object storing the contact parameters of the model.\n    contact_parameters = model.kin_dyn_parameters.contact_parameters\n\n    # Extract the indices corresponding to the enabled collidable points.\n    indices_of_enabled_collidable_points = (\n        contact_parameters.indices_of_enabled_collidable_points\n    )\n\n    # Convert the contact forces to a JAX array.\n    W_f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze())\n\n    # Construct the vector defining the parent link index of each collidable point.\n    # We use this vector to sum the 6D forces of all collidable points rigidly\n    # attached to the same link.\n    parent_link_index_of_collidable_points = jnp.array(\n        contact_parameters.body, dtype=int\n    )[indices_of_enabled_collidable_points]\n\n    # Create the mask that associates each collidable point with its parent link.\n    # We use this mask to accumulate the forces of the collidable points on the right link.\n    mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(\n        model.number_of_links()\n    )\n\n    # Sum the forces of all collidable points rigidly attached to a body.\n    # Since the contact forces W_f_C are expressed in the world frame,\n    # we don't need any coordinate transformation.\n    W_f_L = mask.T @ W_f_C\n\n    return W_f_L\n"
  },
  {
    "path": "src/jaxsim/api/data.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nimport functools\nfrom collections.abc import Sequence\n\ntry:\n    from typing import Self, override\nexcept ImportError:\n    from typing_extensions import override, Self\n\nimport jax\nimport jax.numpy as jnp\nimport jax.scipy.spatial.transform\nimport jax_dataclasses\n\nimport jaxsim.api as js\nimport jaxsim.math\nimport jaxsim.rbda\nimport jaxsim.typing as jtp\n\nfrom . import common\nfrom .common import VelRepr\n\n\n@jax_dataclasses.pytree_dataclass\nclass JaxSimModelData(common.ModelDataWithVelocityRepresentation):\n    \"\"\"\n    Class storing the state of the physics model dynamics.\n\n    Attributes:\n        joint_positions: The vector of joint positions.\n        joint_velocities: The vector of joint velocities.\n        base_position: The 3D position of the base link.\n        base_quaternion: The quaternion defining the orientation of the base link.\n        base_linear_velocity:\n            The linear velocity of the base link in inertial-fixed representation.\n        base_angular_velocity:\n            The angular velocity of the base link in inertial-fixed representation.\n        base_transform: The base transform.\n        joint_transforms: The joint transforms.\n        link_transforms: The link transforms.\n        link_velocities: The link velocities in inertial-fixed representation.\n    \"\"\"\n\n    # Joint state\n    _joint_positions: jtp.Vector\n    _joint_velocities: jtp.Vector\n\n    # Base state\n    _base_quaternion: jtp.Vector\n    _base_linear_velocity: jtp.Vector\n    _base_angular_velocity: jtp.Vector\n    _base_position: jtp.Vector\n\n    # Cached computations.\n    _base_transform: jtp.Matrix = dataclasses.field(repr=False, default=None)\n    _joint_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)\n    _link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)\n    _link_velocities: jtp.Matrix = 
dataclasses.field(repr=False, default=None)\n\n    # Extended state for soft and rigid contact models.\n    contact_state: dict[str, jtp.Array] = dataclasses.field(default_factory=dict)\n\n    @staticmethod\n    def build(\n        model: js.model.JaxSimModel,\n        base_position: jtp.VectorLike | None = None,\n        base_quaternion: jtp.VectorLike | None = None,\n        joint_positions: jtp.VectorLike | None = None,\n        base_linear_velocity: jtp.VectorLike | None = None,\n        base_angular_velocity: jtp.VectorLike | None = None,\n        joint_velocities: jtp.VectorLike | None = None,\n        contact_state: dict[str, jtp.Array] | None = None,\n        velocity_representation: VelRepr = VelRepr.Mixed,\n    ) -> JaxSimModelData:\n        \"\"\"\n        Create a `JaxSimModelData` object with the given state.\n\n        Args:\n            model: The model for which to create the state.\n            base_position: The base position.\n            base_quaternion: The base orientation as a quaternion.\n            joint_positions: The joint positions.\n            base_linear_velocity:\n                The base linear velocity in the selected representation.\n            base_angular_velocity:\n                The base angular velocity in the selected representation.\n            joint_velocities: The joint velocities.\n            velocity_representation: The velocity representation to use. 
It defaults to mixed if not provided.\n            contact_state: The optional contact state.\n\n        Returns:\n            A `JaxSimModelData` initialized with the given state.\n        \"\"\"\n\n        base_position = jnp.array(\n            base_position if base_position is not None else jnp.zeros(3),\n            dtype=float,\n        ).squeeze()\n\n        base_quaternion = jnp.array(\n            (\n                base_quaternion\n                if base_quaternion is not None\n                else jnp.array([1.0, 0, 0, 0])\n            ),\n            dtype=float,\n        ).squeeze()\n\n        base_linear_velocity = jnp.array(\n            base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3),\n            dtype=float,\n        ).squeeze()\n\n        base_angular_velocity = jnp.array(\n            (\n                base_angular_velocity\n                if base_angular_velocity is not None\n                else jnp.zeros(3)\n            ),\n            dtype=float,\n        ).squeeze()\n\n        joint_positions = jnp.atleast_1d(\n            jnp.array(\n                (\n                    joint_positions\n                    if joint_positions is not None\n                    else jnp.zeros(model.dofs())\n                ),\n                dtype=float,\n            ).squeeze()\n        )\n\n        joint_velocities = jnp.atleast_1d(\n            jnp.array(\n                (\n                    joint_velocities\n                    if joint_velocities is not None\n                    else jnp.zeros(model.dofs())\n                ),\n                dtype=float,\n            ).squeeze()\n        )\n\n        W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(\n            translation=base_position, quaternion=base_quaternion\n        )\n\n        W_v_WB = JaxSimModelData.other_representation_to_inertial(\n            array=jnp.hstack([base_linear_velocity, base_angular_velocity]),\n            
other_representation=velocity_representation,\n            transform=W_H_B,\n            is_force=False,\n        ).astype(float)\n\n        joint_transforms = model.kin_dyn_parameters.joint_transforms(\n            joint_positions=joint_positions, base_transform=W_H_B\n        )\n\n        link_transforms, link_velocities_inertial = (\n            jaxsim.rbda.forward_kinematics_model(\n                model=model,\n                base_position=base_position,\n                base_quaternion=base_quaternion,\n                joint_positions=joint_positions,\n                base_linear_velocity_inertial=W_v_WB[0:3],\n                base_angular_velocity_inertial=W_v_WB[3:6],\n                joint_velocities=joint_velocities,\n                joint_transforms=joint_transforms,\n            )\n        )\n\n        contact_state = contact_state or {}\n\n        if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):\n            contact_state[\"tangential_deformation\"] = contact_state.get(\n                \"tangential_deformation\",\n                jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point),\n            )\n\n        model_data = JaxSimModelData(\n            velocity_representation=velocity_representation,\n            _base_quaternion=base_quaternion,\n            _base_position=base_position,\n            _joint_positions=joint_positions,\n            _base_linear_velocity=W_v_WB[0:3],\n            _base_angular_velocity=W_v_WB[3:6],\n            _joint_velocities=joint_velocities,\n            _base_transform=W_H_B,\n            _joint_transforms=joint_transforms,\n            _link_transforms=link_transforms,\n            _link_velocities=link_velocities_inertial,\n            contact_state=contact_state,\n        )\n\n        if not model_data.valid(model=model):\n            raise ValueError(\n                \"The built state is not compatible with the model.\", model_data\n            )\n\n        return 
model_data\n\n    @staticmethod\n    def zero(\n        model: js.model.JaxSimModel,\n        velocity_representation: VelRepr = VelRepr.Mixed,\n    ) -> JaxSimModelData:\n        \"\"\"\n        Create a `JaxSimModelData` object with zero state.\n\n        Args:\n            model: The model for which to create the state.\n            velocity_representation: The velocity representation to use. It defaults to mixed if not provided.\n\n        Returns:\n            A `JaxSimModelData` initialized with zero state.\n        \"\"\"\n        return JaxSimModelData.build(\n            model=model, velocity_representation=velocity_representation\n        )\n\n    # ==================\n    # Extract quantities\n    # ==================\n\n    @property\n    def joint_positions(self) -> jtp.Vector:\n        \"\"\"\n        Get the joint positions.\n\n        Returns:\n            The joint positions.\n        \"\"\"\n        return self._joint_positions\n\n    @property\n    def joint_velocities(self) -> jtp.Vector:\n        \"\"\"\n        Get the joint velocities.\n\n        Returns:\n            The joint velocities.\n        \"\"\"\n        return self._joint_velocities\n\n    @property\n    def base_quaternion(self) -> jtp.Vector:\n        \"\"\"\n        Get the base quaternion.\n\n        Returns:\n            The base quaternion.\n        \"\"\"\n        return self._base_quaternion\n\n    @property\n    def base_position(self) -> jtp.Vector:\n        \"\"\"\n        Get the base position.\n\n        Returns:\n            The base position.\n        \"\"\"\n        return self._base_position\n\n    @property\n    def base_orientation(self) -> jtp.Matrix:\n        \"\"\"\n        Get the base orientation.\n\n        Returns:\n            The base orientation.\n        \"\"\"\n\n        # Extract the base quaternion.\n        W_Q_B = self.base_quaternion\n\n        # Always normalize the quaternion to avoid numerical issues.\n        # If the active scheme does not 
integrate the quaternion on its manifold,\n        # we introduce a Baumgarte stabilization to let the quaternion converge to\n        # a unit quaternion. In this case, it is not guaranteed that the quaternion\n        # stored in the state is a unit quaternion.\n        norm = jaxsim.math.safe_norm(W_Q_B, axis=-1, keepdims=True)\n        W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))\n        return W_Q_B\n\n    @property\n    def base_velocity(self) -> jtp.Vector:\n        \"\"\"\n        Get the base 6D velocity.\n\n        Returns:\n            The base 6D velocity in the active representation.\n        \"\"\"\n\n        W_v_WB = jnp.concatenate(\n            [self._base_linear_velocity, self._base_angular_velocity], axis=-1\n        )\n\n        W_H_B = self._base_transform\n\n        return (\n            JaxSimModelData.inertial_to_other_representation(\n                array=W_v_WB,\n                other_representation=self.velocity_representation,\n                transform=W_H_B,\n                is_force=False,\n            )\n            .squeeze()\n            .astype(float)\n        )\n\n    @property\n    def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:\n        r\"\"\"\n        Get the generalized position\n        :math:`\\mathbf{q} = ({}^W \\mathbf{H}_B, \\mathbf{s}) \\in \\text{SE}(3) \\times \\mathbb{R}^n`.\n\n        Returns:\n            A tuple containing the base transform and the joint positions.\n        \"\"\"\n\n        return self._base_transform, self.joint_positions\n\n    @property\n    def generalized_velocity(self) -> jtp.Vector:\n        r\"\"\"\n        Get the generalized velocity.\n\n        :math:`\\boldsymbol{\\nu} = (\\boldsymbol{v}_{W,B};\\, \\boldsymbol{\\omega}_{W,B};\\, \\dot{\\mathbf{s}}) \\in \\mathbb{R}^{6+n}`\n\n        Returns:\n            The generalized velocity in the active representation.\n        \"\"\"\n\n        return (\n            jnp.hstack([self.base_velocity, 
self.joint_velocities])\n            .squeeze()\n            .astype(float)\n        )\n\n    @property\n    def base_transform(self) -> jtp.Matrix:\n        \"\"\"\n        Get the base transform.\n\n        Returns:\n            The base transform.\n        \"\"\"\n        return self._base_transform\n\n    # ================\n    # Store quantities\n    # ================\n\n    @js.common.named_scope\n    @jax.jit\n    def reset_base_quaternion(\n        self, model: js.model.JaxSimModel, base_quaternion: jtp.VectorLike\n    ) -> Self:\n        \"\"\"\n        Reset the base quaternion.\n\n        Args:\n            model: The JaxSim model to use.\n            base_quaternion: The base orientation as a quaternion.\n\n        Returns:\n            The updated `JaxSimModelData` object.\n        \"\"\"\n\n        W_Q_B = jnp.array(base_quaternion, dtype=float)\n\n        norm = jaxsim.math.safe_norm(W_Q_B, axis=-1)\n        W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))\n\n        return self.replace(model=model, base_quaternion=W_Q_B)\n\n    @js.common.named_scope\n    @jax.jit\n    def reset_base_pose(\n        self, model: js.model.JaxSimModel, base_pose: jtp.MatrixLike\n    ) -> Self:\n        \"\"\"\n        Reset the base pose.\n\n        Args:\n            model: The JaxSim model to use.\n            base_pose: The base pose as an SE(3) matrix.\n\n        Returns:\n            The updated `JaxSimModelData` object.\n        \"\"\"\n\n        base_pose = jnp.array(base_pose)\n        W_p_B = base_pose[0:3, 3]\n        W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3])\n        return self.replace(\n            model=model,\n            base_position=W_p_B,\n            base_quaternion=W_Q_B,\n        )\n\n    @override\n    def replace(\n        self,\n        model: js.model.JaxSimModel,\n        joint_positions: jtp.Vector | None = None,\n        joint_velocities: jtp.Vector | None = None,\n        base_quaternion: jtp.Vector | 
None = None,\n        base_linear_velocity: jtp.Vector | None = None,\n        base_angular_velocity: jtp.Vector | None = None,\n        base_position: jtp.Vector | None = None,\n        *,\n        contact_state: dict[str, jtp.Array] | None = None,\n        validate: bool = False,\n    ) -> Self:\n        \"\"\"\n        Replace the attributes of the `JaxSimModelData` object.\n        \"\"\"\n\n        if joint_positions is None:\n            joint_positions = self.joint_positions\n        if joint_velocities is None:\n            joint_velocities = self.joint_velocities\n        if base_quaternion is None:\n            base_quaternion = self.base_quaternion\n        if base_position is None:\n            base_position = self.base_position\n        if contact_state is None:\n            contact_state = self.contact_state\n\n        # Normalize the quaternion to avoid numerical issues.\n        base_quaternion_norm = jaxsim.math.safe_norm(\n            base_quaternion, axis=-1, keepdims=True\n        )\n        base_quaternion = base_quaternion / jnp.where(\n            base_quaternion_norm == 0, 1.0, base_quaternion_norm\n        )\n\n        joint_positions = jnp.atleast_2d(joint_positions.squeeze()).astype(float)\n        joint_velocities = jnp.atleast_2d(joint_velocities.squeeze()).astype(float)\n        base_quaternion = jnp.atleast_2d(base_quaternion.squeeze()).astype(float)\n        base_position = jnp.atleast_2d(base_position.squeeze()).astype(float)\n\n        base_transform = jaxsim.math.Transform.from_quaternion_and_translation(\n            translation=base_position, quaternion=base_quaternion\n        ).reshape((-1, 4, 4))\n\n        joint_transforms = jax.vmap(model.kin_dyn_parameters.joint_transforms)(\n            joint_positions=joint_positions,\n            base_transform=base_transform,\n        )\n\n        if base_linear_velocity is None and base_angular_velocity is None:\n            base_linear_velocity_inertial = self._base_linear_velocity\n 
           base_angular_velocity_inertial = self._base_angular_velocity\n        else:\n            if base_linear_velocity is None:\n                base_linear_velocity = self.base_velocity[:3]\n            if base_angular_velocity is None:\n                base_angular_velocity = self.base_velocity[3:]\n\n            base_linear_velocity = jnp.atleast_1d(base_linear_velocity.squeeze())\n            base_angular_velocity = jnp.atleast_1d(base_angular_velocity.squeeze())\n\n            W_v_WB = JaxSimModelData.other_representation_to_inertial(\n                array=jnp.hstack([base_linear_velocity, base_angular_velocity]),\n                other_representation=self.velocity_representation,\n                transform=base_transform,\n                is_force=False,\n            ).astype(float)\n\n            base_linear_velocity_inertial, base_angular_velocity_inertial = (\n                W_v_WB[..., :3],\n                W_v_WB[..., 3:],\n            )\n\n        link_transforms, link_velocities = jax.vmap(\n            jaxsim.rbda.forward_kinematics_model, in_axes=(None,)\n        )(\n            model,\n            base_position=base_position,\n            base_quaternion=base_quaternion,\n            joint_positions=joint_positions,\n            joint_velocities=joint_velocities,\n            base_linear_velocity_inertial=jnp.atleast_2d(base_linear_velocity_inertial),\n            base_angular_velocity_inertial=jnp.atleast_2d(\n                base_angular_velocity_inertial\n            ),\n            joint_transforms=joint_transforms,\n        )\n\n        # Adjust the output shapes.\n        joint_positions = joint_positions.reshape(self._joint_positions.shape)\n        joint_velocities = joint_velocities.reshape(self._joint_velocities.shape)\n        base_quaternion = base_quaternion.reshape(self._base_quaternion.shape)\n        base_linear_velocity_inertial = base_linear_velocity_inertial.reshape(\n            self._base_linear_velocity.shape\n        
)\n        base_angular_velocity_inertial = base_angular_velocity_inertial.reshape(\n            self._base_angular_velocity.shape\n        )\n        base_position = base_position.reshape(self._base_position.shape)\n        base_transform = base_transform.reshape(self._base_transform.shape)\n        joint_transforms = joint_transforms.reshape(self._joint_transforms.shape)\n        link_transforms = link_transforms.reshape(self._link_transforms.shape)\n        link_velocities = link_velocities.reshape(self._link_velocities.shape)\n\n        return super().replace(\n            _joint_positions=joint_positions,\n            _joint_velocities=joint_velocities,\n            _base_quaternion=base_quaternion,\n            _base_linear_velocity=base_linear_velocity_inertial,\n            _base_angular_velocity=base_angular_velocity_inertial,\n            _base_position=base_position,\n            _base_transform=base_transform,\n            _joint_transforms=joint_transforms,\n            _link_transforms=link_transforms,\n            _link_velocities=link_velocities,\n            contact_state=contact_state,\n            validate=validate,\n        )\n\n    def valid(self, model: js.model.JaxSimModel) -> bool:\n        \"\"\"\n        Check if the `JaxSimModelData` is valid for a given `JaxSimModel`.\n\n        Args:\n            model: The `JaxSimModel` to validate the `JaxSimModelData` against.\n\n        Returns:\n            `True` if the `JaxSimModelData` is valid for the given model,\n            `False` otherwise.\n        \"\"\"\n        if self._joint_positions.shape != (model.dofs(),):\n            return False\n        if self._joint_velocities.shape != (model.dofs(),):\n            return False\n        if self._base_position.shape != (3,):\n            return False\n        if self._base_quaternion.shape != (4,):\n            return False\n        if self._base_linear_velocity.shape != (3,):\n            return False\n        if 
self._base_angular_velocity.shape != (3,):\n            return False\n\n        return True\n\n\n@functools.partial(jax.jit, static_argnames=[\"velocity_representation\", \"base_rpy_seq\"])\ndef random_model_data(\n    model: js.model.JaxSimModel,\n    *,\n    key: jax.Array | None = None,\n    velocity_representation: VelRepr | None = None,\n    base_pos_bounds: tuple[\n        jtp.FloatLike | Sequence[jtp.FloatLike],\n        jtp.FloatLike | Sequence[jtp.FloatLike],\n    ] = ((-1, -1, 0.5), 1.0),\n    base_rpy_bounds: tuple[\n        jtp.FloatLike | Sequence[jtp.FloatLike],\n        jtp.FloatLike | Sequence[jtp.FloatLike],\n    ] = (-jnp.pi, jnp.pi),\n    base_rpy_seq: str = \"XYZ\",\n    joint_pos_bounds: (\n        tuple[\n            jtp.FloatLike | Sequence[jtp.FloatLike],\n            jtp.FloatLike | Sequence[jtp.FloatLike],\n        ]\n        | None\n    ) = None,\n    base_vel_lin_bounds: tuple[\n        jtp.FloatLike | Sequence[jtp.FloatLike],\n        jtp.FloatLike | Sequence[jtp.FloatLike],\n    ] = (-1.0, 1.0),\n    base_vel_ang_bounds: tuple[\n        jtp.FloatLike | Sequence[jtp.FloatLike],\n        jtp.FloatLike | Sequence[jtp.FloatLike],\n    ] = (-1.0, 1.0),\n    joint_vel_bounds: tuple[\n        jtp.FloatLike | Sequence[jtp.FloatLike],\n        jtp.FloatLike | Sequence[jtp.FloatLike],\n    ] = (-1.0, 1.0),\n) -> JaxSimModelData:\n    \"\"\"\n    Randomly generate a `JaxSimModelData` object.\n\n    Args:\n        model: The target model for the random data.\n        key: The random key.\n        velocity_representation: The velocity representation to use.\n        base_pos_bounds: The bounds for the base position.\n        base_rpy_bounds:\n            The bounds for the euler angles used to build the base orientation.\n        base_rpy_seq:\n            The sequence of axes for rotation (using `Rotation` from scipy).\n        joint_pos_bounds:\n            The bounds for the joint positions (reading the joint limits if None).\n        
base_vel_lin_bounds: The bounds for the base linear velocity.\n        base_vel_ang_bounds: The bounds for the base angular velocity.\n        joint_vel_bounds: The bounds for the joint velocities.\n\n    Returns:\n        A `JaxSimModelData` object with random data.\n    \"\"\"\n\n    key = key if key is not None else jax.random.PRNGKey(seed=0)\n    k1, k2, k3, k4, k5, k6 = jax.random.split(key, num=6)\n\n    p_min = jnp.array(base_pos_bounds[0], dtype=float)\n    p_max = jnp.array(base_pos_bounds[1], dtype=float)\n    rpy_min = jnp.array(base_rpy_bounds[0], dtype=float)\n    rpy_max = jnp.array(base_rpy_bounds[1], dtype=float)\n    v_min = jnp.array(base_vel_lin_bounds[0], dtype=float)\n    v_max = jnp.array(base_vel_lin_bounds[1], dtype=float)\n    ω_min = jnp.array(base_vel_ang_bounds[0], dtype=float)\n    ω_max = jnp.array(base_vel_ang_bounds[1], dtype=float)\n    ṡ_min, ṡ_max = joint_vel_bounds\n\n    base_position = jax.random.uniform(key=k1, shape=(3,), minval=p_min, maxval=p_max)\n\n    base_quaternion = jaxsim.math.Quaternion.to_wxyz(\n        xyzw=jax.scipy.spatial.transform.Rotation.from_euler(\n            seq=base_rpy_seq,\n            angles=jax.random.uniform(\n                key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max\n            ),\n        ).as_quat()\n    )\n\n    (\n        joint_positions,\n        joint_velocities,\n        base_linear_velocity,\n        base_angular_velocity,\n    ) = (None,) * 4\n\n    if model.number_of_joints() > 0:\n\n        s_min, s_max = (\n            jnp.array(joint_pos_bounds, dtype=float)\n            if joint_pos_bounds is not None\n            else (None, None)\n        )\n\n        joint_positions = (\n            js.joint.random_joint_positions(model=model, key=k3)\n            if (s_min is None or s_max is None)\n            else jax.random.uniform(\n                key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max\n            )\n        )\n\n        joint_velocities = 
jax.random.uniform(\n            key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max\n        )\n\n    if model.floating_base():\n        base_linear_velocity = jax.random.uniform(\n            key=k5, shape=(3,), minval=v_min, maxval=v_max\n        )\n\n        base_angular_velocity = jax.random.uniform(\n            key=k6, shape=(3,), minval=ω_min, maxval=ω_max\n        )\n\n    return JaxSimModelData.build(\n        model=model,\n        base_position=base_position,\n        base_quaternion=base_quaternion,\n        joint_positions=joint_positions,\n        joint_velocities=joint_velocities,\n        base_linear_velocity=base_linear_velocity,\n        base_angular_velocity=base_angular_velocity,\n        **(\n            {\"velocity_representation\": velocity_representation}\n            if velocity_representation is not None\n            else {}\n        ),\n    )\n"
  },
  {
    "path": "src/jaxsim/api/frame.py",
    "content": "import functools\nfrom collections.abc import Sequence\n\nimport jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim import exceptions\nfrom jaxsim.math import Adjoint, Cross\n\nfrom .common import VelRepr\n\n# =======================\n# Index-related functions\n# =======================\n\n\n@jax.jit\n@js.common.named_scope\ndef idx_of_parent_link(\n    model: js.model.JaxSimModel, *, frame_index: jtp.IntLike\n) -> jtp.Int:\n    \"\"\"\n    Get the index of the link to which the frame is rigidly attached.\n\n    Args:\n        model: The model to consider.\n        frame_index: The index of the frame.\n\n    Returns:\n        The index of the frame's parent link.\n    \"\"\"\n\n    n_l = model.number_of_links()\n    n_f = len(model.frame_names())\n\n    exceptions.raise_value_error_if(\n        condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),\n        msg=\"Invalid frame index '{idx}'\",\n        idx=frame_index,\n    )\n\n    return jnp.array(model.kin_dyn_parameters.frame_parameters.body)[\n        frame_index - model.number_of_links()\n    ]\n\n\n@functools.partial(jax.jit, static_argnames=\"frame_name\")\n@js.common.named_scope\ndef name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int:\n    \"\"\"\n    Convert the name of a frame to its index.\n\n    Args:\n        model: The model to consider.\n        frame_name: The name of the frame.\n\n    Returns:\n        The index of the frame.\n    \"\"\"\n\n    if frame_name not in model.frame_names():\n        raise ValueError(f\"Frame '{frame_name}' not found in the model.\")\n\n    return (\n        jnp.array(\n            model.number_of_links()\n            + model.kin_dyn_parameters.frame_parameters.name.index(frame_name)\n        )\n        .astype(int)\n        .squeeze()\n    )\n\n\ndef idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str:\n    \"\"\"\n    Convert the index of a 
frame to its name.\n\n    Args:\n        model: The model to consider.\n        frame_index: The index of the frame.\n\n    Returns:\n        The name of the frame.\n    \"\"\"\n\n    n_l = model.number_of_links()\n    n_f = len(model.frame_names())\n\n    exceptions.raise_value_error_if(\n        condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),\n        msg=\"Invalid frame index '{idx}'\",\n        idx=frame_index,\n    )\n\n    return model.kin_dyn_parameters.frame_parameters.name[\n        frame_index - model.number_of_links()\n    ]\n\n\n@functools.partial(jax.jit, static_argnames=[\"frame_names\"])\n@js.common.named_scope\ndef names_to_idxs(\n    model: js.model.JaxSimModel, *, frame_names: Sequence[str]\n) -> jax.Array:\n    \"\"\"\n    Convert a sequence of frame names to their corresponding indices.\n\n    Args:\n        model: The model to consider.\n        frame_names: The names of the frames.\n\n    Returns:\n        The indices of the frames.\n    \"\"\"\n\n    return jnp.array(\n        [name_to_idx(model=model, frame_name=name) for name in frame_names]\n    ).astype(int)\n\n\ndef idxs_to_names(\n    model: js.model.JaxSimModel, *, frame_indices: Sequence[jtp.IntLike]\n) -> tuple[str, ...]:\n    \"\"\"\n    Convert a sequence of frame indices to their corresponding names.\n\n    Args:\n        model: The model to consider.\n        frame_indices: The indices of the frames.\n\n    Returns:\n        The names of the frames.\n    \"\"\"\n\n    return tuple(idx_to_name(model=model, frame_index=idx) for idx in frame_indices)\n\n\n# ==========\n# Frame APIs\n# ==========\n\n\n@jax.jit\n@js.common.named_scope\ndef transform(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    frame_index: jtp.IntLike,\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the SE(3) transform from the world frame to the specified frame.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered 
model.\n        frame_index: The index of the frame for which the transform is requested.\n\n    Returns:\n        The 4x4 matrix representing the transform.\n    \"\"\"\n\n    n_l = model.number_of_links()\n    n_f = len(model.frame_names())\n\n    exceptions.raise_value_error_if(\n        condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),\n        msg=\"Invalid frame index '{idx}'\",\n        idx=frame_index,\n    )\n\n    # Compute the necessary transforms.\n    L = idx_of_parent_link(model=model, frame_index=frame_index)\n    W_H_L = js.link.transform(model=model, data=data, link_index=L)\n\n    # Get the static frame pose wrt the parent link.\n    L_H_F = model.kin_dyn_parameters.frame_parameters.transform[\n        frame_index - model.number_of_links()\n    ]\n\n    # Combine the transforms computing the frame pose.\n    return W_H_L @ L_H_F\n\n\n@functools.partial(jax.jit, static_argnames=[\"output_vel_repr\"])\n@js.common.named_scope\ndef velocity(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    frame_index: jtp.IntLike,\n    output_vel_repr: VelRepr | None = None,\n) -> jtp.Vector:\n    \"\"\"\n    Compute the 6D velocity of the frame.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        frame_index: The index of the frame.\n        output_vel_repr:\n            The output velocity representation of the frame velocity.\n\n    Returns:\n        The 6D velocity of the frame in the specified velocity representation.\n    \"\"\"\n    n_l = model.number_of_links()\n    n_f = model.number_of_frames()\n\n    exceptions.raise_value_error_if(\n        condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),\n        msg=\"Invalid frame index '{idx}'\",\n        idx=frame_index,\n    )\n\n    output_vel_repr = (\n        output_vel_repr if output_vel_repr is not None else data.velocity_representation\n    )\n\n    # Get the frame 
jacobian having I as input representation (taken from data)\n    # and O as output representation, specified by the user (or taken from data).\n    O_J_WF_I = jacobian(\n        model=model,\n        data=data,\n        frame_index=frame_index,\n        output_vel_repr=output_vel_repr,\n    )\n\n    # Get the generalized velocity in the input velocity representation.\n    I_ν = data.generalized_velocity\n\n    # Compute the frame velocity in the output velocity representation.\n    return O_J_WF_I @ I_ν\n\n\n@functools.partial(jax.jit, static_argnames=[\"output_vel_repr\"])\n@js.common.named_scope\ndef jacobian(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    frame_index: jtp.IntLike,\n    output_vel_repr: VelRepr | None = None,\n) -> jtp.Matrix:\n    r\"\"\"\n    Compute the free-floating jacobian of the frame.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        frame_index: The index of the frame.\n        output_vel_repr:\n            The output velocity representation of the free-floating jacobian.\n\n    Returns:\n        The :math:`6 \\times (6+n)` free-floating jacobian of the frame.\n\n    Note:\n        The input representation of the free-floating jacobian is the active\n        velocity representation.\n    \"\"\"\n\n    n_l = model.number_of_links()\n    n_f = model.number_of_frames()\n\n    exceptions.raise_value_error_if(\n        condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),\n        msg=\"Invalid frame index '{idx}'\",\n        idx=frame_index,\n    )\n\n    output_vel_repr = (\n        output_vel_repr if output_vel_repr is not None else data.velocity_representation\n    )\n\n    # Get the index of the parent link.\n    L = idx_of_parent_link(model=model, frame_index=frame_index)\n\n    # Compute only the parent-link body Jacobian.\n    L_J_WL = js.link.jacobian(\n        model=model,\n        data=data,\n        link_index=L,\n   
     output_vel_repr=VelRepr.Body,\n    )\n    W_H_L = data._link_transforms[L]\n    L_H_F = model.kin_dyn_parameters.frame_parameters.transform[\n        frame_index - model.number_of_links()\n    ]\n    L_p_F = L_H_F[0:3, 3]\n\n    # Adjust the output representation.\n    match output_vel_repr:\n        case VelRepr.Inertial:\n            W_X_L = Adjoint.from_rotation_and_translation(\n                rotation=W_H_L[0:3, 0:3],\n                translation=W_H_L[0:3, 3],\n            )\n            W_J_WL = W_X_L @ L_J_WL\n            O_J_WL_I = W_J_WL\n\n        case VelRepr.Body:\n            F_X_L = Adjoint.from_rotation_and_translation(\n                rotation=L_H_F[0:3, 0:3],\n                translation=L_p_F,\n                inverse=True,\n            )\n            F_J_WL = F_X_L @ L_J_WL\n            O_J_WL_I = F_J_WL\n\n        case VelRepr.Mixed:\n            W_R_L = W_H_L[0:3, 0:3]\n            FW_X_L = Adjoint.from_rotation_and_translation(\n                rotation=W_R_L,\n                translation=-W_R_L @ L_p_F,\n            )\n            FW_J_WL = FW_X_L @ L_J_WL\n            O_J_WL_I = FW_J_WL\n\n        case _:\n            raise ValueError(output_vel_repr)\n\n    return O_J_WL_I\n\n\n@functools.partial(jax.jit, static_argnames=[\"output_vel_repr\"])\n@js.common.named_scope\ndef jacobian_derivative(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    frame_index: jtp.IntLike,\n    output_vel_repr: VelRepr | None = None,\n) -> jtp.Matrix:\n    r\"\"\"\n    Compute the derivative of the free-floating jacobian of the frame.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        frame_index: The index of the frame.\n        output_vel_repr:\n            The output velocity representation of the free-floating jacobian derivative.\n\n    Returns:\n        The derivative of the :math:`6 \\times (6+n)` free-floating jacobian of the frame.\n\n    Note:\n        
The input representation of the free-floating jacobian derivative is the active\n        velocity representation.\n    \"\"\"\n\n    n_l = model.number_of_links()\n    n_f = len(model.frame_names())\n\n    exceptions.raise_value_error_if(\n        condition=jnp.array([frame_index < n_l, frame_index >= n_l + n_f]).any(),\n        msg=\"Invalid frame index '{idx}'\",\n        idx=frame_index,\n    )\n\n    output_vel_repr = (\n        output_vel_repr if output_vel_repr is not None else data.velocity_representation\n    )\n\n    # Get the index of the parent link.\n    L = idx_of_parent_link(model=model, frame_index=frame_index)\n    W_J_WL_I = js.link.jacobian(\n        model=model,\n        data=data,\n        link_index=L,\n        output_vel_repr=VelRepr.Inertial,\n    )\n    W_J̇_WL_I = js.link.jacobian_derivative(\n        model=model,\n        data=data,\n        link_index=L,\n        output_vel_repr=VelRepr.Inertial,\n    )\n\n    W_H_L = data._link_transforms[L]\n    L_H_F = model.kin_dyn_parameters.frame_parameters.transform[\n        frame_index - model.number_of_links()\n    ]\n    W_H_F = W_H_L @ L_H_F\n\n    # =====================================================\n    # Compute quantities to adjust the output representation\n    # =====================================================\n\n    W_v_WF = W_J_WL_I @ data.generalized_velocity\n\n    match output_vel_repr:\n        case VelRepr.Inertial:\n            O_X_W = jnp.eye(6, dtype=W_H_F.dtype)\n            O_Ẋ_W = jnp.zeros((6, 6), dtype=W_H_F.dtype)\n\n        case VelRepr.Body:\n            O_X_W = Adjoint.from_rotation_and_translation(\n                rotation=W_H_F[0:3, 0:3],\n                translation=W_H_F[0:3, 3],\n                inverse=True,\n            )\n            O_Ẋ_W = -O_X_W @ Cross.vx(W_v_WF)\n\n        case VelRepr.Mixed:\n            O_X_W = Adjoint.from_rotation_and_translation(\n                rotation=jnp.eye(3, dtype=W_H_F.dtype),\n                
translation=W_H_F[0:3, 3],\n                inverse=True,\n            )\n            FW_v_WF = O_X_W @ W_v_WF\n            W_v_W_FW = FW_v_WF.at[3:6].set(jnp.zeros_like(FW_v_WF[3:6]))\n            O_Ẋ_W = -O_X_W @ Cross.vx(W_v_W_FW)\n\n        case _:\n            raise ValueError(output_vel_repr)\n\n    O_J̇_WF_I = O_Ẋ_W @ W_J_WL_I\n    O_J̇_WF_I += O_X_W @ W_J̇_WL_I\n\n    return O_J̇_WF_I\n"
  },
  {
    "path": "src/jaxsim/api/integrators.py",
    "content": "import dataclasses\nfrom collections.abc import Callable\n\nimport jax\nimport jax.numpy as jnp\n\nimport jaxsim\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.api.data import JaxSimModelData\nfrom jaxsim.math import Skew\n\n\ndef semi_implicit_euler_integration(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    link_forces: jtp.Vector,\n    joint_torques: jtp.Vector,\n) -> JaxSimModelData:\n    \"\"\"Integrate the system state using the semi-implicit Euler method.\"\"\"\n\n    with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):\n\n        # Compute the system acceleration\n        W_v̇_WB, s̈, contact_state_derivative = js.ode.system_acceleration(\n            model=model,\n            data=data,\n            link_forces=link_forces,\n            joint_torques=joint_torques,\n        )\n\n        dt = model.time_step\n\n        # Compute the new generalized velocity.\n        new_generalized_acceleration = jnp.hstack([W_v̇_WB, s̈])\n        new_generalized_velocity = (\n            data.generalized_velocity + dt * new_generalized_acceleration\n        )\n\n        # Extract the new base and joint velocities.\n        W_v_B = new_generalized_velocity[0:6]\n        ṡ = new_generalized_velocity[6:]\n\n        # Compute the new base position and orientation.\n        W_ω_WB = new_generalized_velocity[3:6]\n\n        # To obtain the derivative of the base position, we need to subtract\n        # the skew-symmetric matrix of the base angular velocity times the base position.\n        # See: S. Traversaro and A. 
Saccon, “Multibody Dynamics Notation (Version 2), pg.9\n        W_ṗ_B = new_generalized_velocity[0:3] + Skew.wedge(W_ω_WB) @ data.base_position\n\n        W_Q̇_B = jaxsim.math.Quaternion.derivative(\n            quaternion=data.base_orientation,\n            omega=W_ω_WB,\n            omega_in_body_fixed=False,\n        ).squeeze()\n\n        W_p_B = data.base_position + dt * W_ṗ_B\n        W_Q_B = data.base_orientation + dt * W_Q̇_B\n\n        base_quaternion_norm = jaxsim.math.safe_norm(W_Q_B, axis=-1)\n\n        W_Q_B = W_Q_B / jnp.where(base_quaternion_norm == 0, 1.0, base_quaternion_norm)\n\n        s = data.joint_positions + dt * ṡ\n\n        integrated_contact_state = jax.tree.map(\n            lambda x, x_dot: x + dt * x_dot,\n            data.contact_state,\n            contact_state_derivative,\n        )\n\n    data = dataclasses.replace(\n        data,\n        _base_quaternion=W_Q_B,\n        _base_position=W_p_B,\n        _joint_positions=s,\n        _joint_velocities=ṡ,\n        _base_linear_velocity=W_v_B[0:3],\n        _base_angular_velocity=W_ω_WB,\n        contact_state=integrated_contact_state,\n    )\n\n    # Recompute kinematic caches for the new state.\n    data = data.replace(model=model)\n\n    return data\n\n\ndef rk4_integration(\n    model: js.model.JaxSimModel,\n    data: JaxSimModelData,\n    link_forces: jtp.Vector,\n    joint_torques: jtp.Vector,\n) -> JaxSimModelData:\n    \"\"\"Integrate the system state using the Runge-Kutta 4 method.\"\"\"\n\n    dt = model.time_step\n\n    def f(x) -> dict[str, jtp.Matrix]:\n\n        with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):\n\n            data_ti = data.replace(model=model, **x)\n\n            return js.ode.system_dynamics(\n                model=model,\n                data=data_ti,\n                link_forces=link_forces,\n                joint_torques=joint_torques,\n            )\n\n    base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion, 
axis=-1)\n    base_quaternion = data._base_quaternion / jnp.where(\n        base_quaternion_norm == 0, 1.0, base_quaternion_norm\n    )\n\n    x_t0 = dict(\n        base_position=data._base_position,\n        base_quaternion=base_quaternion,\n        joint_positions=data._joint_positions,\n        base_linear_velocity=data._base_linear_velocity,\n        base_angular_velocity=data._base_angular_velocity,\n        joint_velocities=data._joint_velocities,\n        contact_state=data.contact_state,\n    )\n\n    euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt\n    euler_fin = lambda x, dxdt: x + dt * dxdt\n\n    k1 = f(x_t0)\n    k2 = f(jax.tree.map(euler_mid, x_t0, k1))\n    k3 = f(jax.tree.map(euler_mid, x_t0, k2))\n    k4 = f(jax.tree.map(euler_fin, x_t0, k3))\n\n    # Average the slopes and compute the RK4 state derivative.\n    average = lambda k1, k2, k3, k4: (k1 + 2 * k2 + 2 * k3 + k4) / 6\n\n    dxdt = jax.tree.map(average, k1, k2, k3, k4)\n\n    # Integrate the dynamics\n    x_tf = jax.tree.map(euler_fin, x_t0, dxdt)\n\n    data_tf = dataclasses.replace(\n        data,\n        _base_position=x_tf[\"base_position\"],\n        _base_quaternion=x_tf[\"base_quaternion\"],\n        _joint_positions=x_tf[\"joint_positions\"],\n        _base_linear_velocity=x_tf[\"base_linear_velocity\"],\n        _base_angular_velocity=x_tf[\"base_angular_velocity\"],\n        _joint_velocities=x_tf[\"joint_velocities\"],\n        contact_state=x_tf[\"contact_state\"],\n    )\n\n    return data_tf.replace(model=model)\n\n\ndef rk4fast_integration(\n    model: js.model.JaxSimModel,\n    data: JaxSimModelData,\n    link_forces: jtp.Vector,\n    joint_torques: jtp.Vector,\n) -> JaxSimModelData:\n    \"\"\"\n    Integrate the system state using the Runge-Kutta 4 fast method.\n\n    Note:\n        This method is a faster version of the RK4 method, but it may not be as accurate.\n        It computes the contact forces only once at the beginning of the integration step.\n    \"\"\"\n\n 
   dt = model.time_step\n\n    if len(model.kin_dyn_parameters.contact_parameters.body) > 0:\n\n        # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact\n        # with the terrain.\n        W_f_L_terrain, contact_state_derivative = js.contact.link_contact_forces(\n            model=model,\n            data=data,\n            link_forces=link_forces,\n            joint_torques=joint_torques,\n        )\n\n    W_f_L_total = link_forces + W_f_L_terrain\n\n    # Update the contact state data. This is necessary only for the contact models\n    # that require propagation and integration of contact state.\n    contact_state = model.contact_model.update_contact_state(contact_state_derivative)\n\n    def f(x) -> dict[str, jtp.Matrix]:\n\n        with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):\n\n            data_ti = data.replace(model=model, **x)\n\n            W_v̇_WB, s̈ = js.model.forward_dynamics_aba(\n                model=model,\n                data=data_ti,\n                joint_forces=joint_torques,\n                link_forces=W_f_L_total,\n            )\n\n            W_ṗ_B, W_Q̇_B, ṡ = js.ode.system_position_dynamics(\n                data=data,\n                baumgarte_quaternion_regularization=1.0,\n            )\n\n        return dict(\n            base_position=W_ṗ_B,\n            base_quaternion=W_Q̇_B,\n            joint_positions=ṡ,\n            base_linear_velocity=W_v̇_WB[0:3],\n            base_angular_velocity=W_v̇_WB[3:6],\n            joint_velocities=s̈,\n            # The contact state is not updated here, as it is assumed to be constant.\n            contact_state=data_ti.contact_state,\n        )\n\n    base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion, axis=-1)\n    base_quaternion = data._base_quaternion / jnp.where(\n        base_quaternion_norm == 0, 1.0, base_quaternion_norm\n    )\n\n    x_t0 = dict(\n        base_position=data._base_position,\n        
base_quaternion=base_quaternion,\n        joint_positions=data._joint_positions,\n        base_linear_velocity=data._base_linear_velocity,\n        base_angular_velocity=data._base_angular_velocity,\n        joint_velocities=data._joint_velocities,\n        contact_state=contact_state,\n    )\n\n    euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt\n    euler_fin = lambda x, dxdt: x + dt * dxdt\n\n    k1 = f(x_t0)\n    k2 = f(jax.tree.map(euler_mid, x_t0, k1))\n    k3 = f(jax.tree.map(euler_mid, x_t0, k2))\n    k4 = f(jax.tree.map(euler_fin, x_t0, k3))\n\n    # Average the slopes and compute the RK4 state derivative.\n    average = lambda k1, k2, k3, k4: (k1 + 2 * k2 + 2 * k3 + k4) / 6\n\n    dxdt = jax.tree.map(average, k1, k2, k3, k4)\n\n    # Integrate the dynamics\n    x_tf = jax.tree.map(euler_fin, x_t0, dxdt)\n\n    data_tf = dataclasses.replace(\n        data,\n        _base_position=x_tf[\"base_position\"],\n        _base_quaternion=x_tf[\"base_quaternion\"],\n        _joint_positions=x_tf[\"joint_positions\"],\n        _base_linear_velocity=x_tf[\"base_linear_velocity\"],\n        _base_angular_velocity=x_tf[\"base_angular_velocity\"],\n        _joint_velocities=x_tf[\"joint_velocities\"],\n        contact_state=x_tf[\"contact_state\"],\n    )\n\n    return data_tf.replace(model=model)\n\n\n_INTEGRATORS_MAP: dict[\n    js.model.IntegratorType, Callable[..., js.data.JaxSimModelData]\n] = {\n    js.model.IntegratorType.SemiImplicitEuler: semi_implicit_euler_integration,\n    js.model.IntegratorType.RungeKutta4: rk4_integration,\n    js.model.IntegratorType.RungeKutta4Fast: rk4fast_integration,\n}\n"
  },
  {
    "path": "src/jaxsim/api/joint.py",
    "content": "import functools\nfrom collections.abc import Sequence\n\nimport jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim import exceptions\n\n# =======================\n# Index-related functions\n# =======================\n\n\n@functools.partial(jax.jit, static_argnames=\"joint_name\")\n@js.common.named_scope\ndef name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:\n    \"\"\"\n    Convert the name of a joint to its index.\n\n    Args:\n        model: The model to consider.\n        joint_name: The name of the joint.\n\n    Returns:\n        The index of the joint.\n    \"\"\"\n\n    if joint_name not in model.joint_names():\n        raise ValueError(f\"Joint '{joint_name}' not found in the model.\")\n\n    # Note: the index of the joint for RBDAs starts from 1, but the index for\n    # accessing the right element starts from 0. Therefore, there is a -1.\n    return (\n        jnp.array(\n            model.kin_dyn_parameters.joint_model.joint_names.index(joint_name) - 1\n        )\n        .astype(int)\n        .squeeze()\n    )\n\n\ndef idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:\n    \"\"\"\n    Convert the index of a joint to its name.\n\n    Args:\n        model: The model to consider.\n        joint_index: The index of the joint.\n\n    Returns:\n        The name of the joint.\n    \"\"\"\n\n    exceptions.raise_value_error_if(\n        condition=joint_index < 0,\n        msg=\"Invalid joint index '{idx}'\",\n        idx=joint_index,\n    )\n\n    return model.kin_dyn_parameters.joint_model.joint_names[joint_index + 1]\n\n\n@functools.partial(jax.jit, static_argnames=\"joint_names\")\n@js.common.named_scope\ndef names_to_idxs(\n    model: js.model.JaxSimModel, *, joint_names: Sequence[str]\n) -> jax.Array:\n    \"\"\"\n    Convert a sequence of joint names to their corresponding indices.\n\n    Args:\n        model: The model to consider.\n        
joint_names: The names of the joints.\n\n    Returns:\n        The indices of the joints.\n    \"\"\"\n\n    return jnp.array(\n        [name_to_idx(model=model, joint_name=name) for name in joint_names],\n    ).astype(int)\n\n\ndef idxs_to_names(\n    model: js.model.JaxSimModel,\n    *,\n    joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike,\n) -> tuple[str, ...]:\n    \"\"\"\n    Convert a sequence of joint indices to their corresponding names.\n\n    Args:\n        model: The model to consider.\n        joint_indices: The indices of the joints.\n\n    Returns:\n        The names of the joints.\n    \"\"\"\n\n    return tuple(idx_to_name(model=model, joint_index=idx) for idx in joint_indices)\n\n\n# ============\n# Joint limits\n# ============\n\n\n@jax.jit\ndef position_limit(\n    model: js.model.JaxSimModel, *, joint_index: jtp.IntLike\n) -> tuple[jtp.Float, jtp.Float]:\n    \"\"\"\n    Get the position limits of a joint.\n\n    Args:\n        model: The model to consider.\n        joint_index: The index of the joint.\n\n    Returns:\n        The position limits of the joint.\n    \"\"\"\n\n    if model.number_of_joints() == 0:\n        return jnp.empty(0).astype(float), jnp.empty(0).astype(float)\n\n    exceptions.raise_value_error_if(\n        condition=jnp.array(\n            [joint_index < 0, joint_index >= model.number_of_joints()]\n        ).any(),\n        msg=\"Invalid joint index '{idx}'\",\n        idx=joint_index,\n    )\n\n    s_min = jnp.atleast_1d(\n        model.kin_dyn_parameters.joint_parameters.position_limits_min\n    )[joint_index]\n    s_max = jnp.atleast_1d(\n        model.kin_dyn_parameters.joint_parameters.position_limits_max\n    )[joint_index]\n\n    return s_min.astype(float), s_max.astype(float)\n\n\n@functools.partial(jax.jit, static_argnames=[\"joint_names\"])\n@js.common.named_scope\ndef position_limits(\n    model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None\n) -> tuple[jtp.Vector, jtp.Vector]:\n    
\"\"\"\n    Get the position limits of a list of joint.\n\n    Args:\n        model: The model to consider.\n        joint_names: The names of the joints.\n\n    Returns:\n        The position limits of the joints.\n    \"\"\"\n\n    joint_idxs = (\n        names_to_idxs(joint_names=joint_names, model=model)\n        if joint_names is not None\n        else jnp.arange(model.number_of_joints())\n    )\n\n    if len(joint_idxs) == 0:\n        return jnp.empty(0).astype(float), jnp.empty(0).astype(float)\n\n    s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_idxs]\n    s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_idxs]\n\n    return s_min.astype(float), s_max.astype(float)\n\n\n# ======================\n# Random data generation\n# ======================\n\n\n@functools.partial(jax.jit, static_argnames=[\"joint_names\"])\n@js.common.named_scope\ndef random_joint_positions(\n    model: js.model.JaxSimModel,\n    *,\n    joint_names: Sequence[str] | None = None,\n    key: jax.Array | None = None,\n) -> jtp.Vector:\n    \"\"\"\n    Generate random joint positions.\n\n    Args:\n        model: The model to consider.\n        joint_names: The names of the considered joints (all if None).\n        key: The random key (initialized from seed 0 if None).\n\n    Note:\n        If the joint range or revolute joints is larger than 2π, their joint positions\n        will be sampled from an interval of size 2π.\n\n    Returns:\n        The random joint positions.\n    \"\"\"\n\n    # Consider the key corresponding to a zero seed if it was not passed.\n    key = key if key is not None else jax.random.PRNGKey(seed=0)\n\n    # Get the joint limits parsed from the model description.\n    s_min, s_max = position_limits(model=model, joint_names=joint_names)\n\n    # Get the joint indices.\n    # Note that it will trigger an exception if the given `joint_names` are not valid.\n    joint_names = joint_names if joint_names is not 
None else model.joint_names()\n    joint_indices = (\n        names_to_idxs(model=model, joint_names=joint_names)\n        if joint_names is not None\n        else jnp.arange(model.number_of_joints())\n    )\n\n    from jaxsim.parsers.descriptions.joint import JointType\n\n    # Filter for revolute joints.\n    is_revolute = jnp.where(\n        jnp.array(model.kin_dyn_parameters.joint_model.joint_types[1:])[joint_indices]\n        == JointType.Revolute,\n        True,\n        False,\n    )\n\n    # Shorthand for π.\n    π = jnp.pi\n\n    # Filter for revolute with full range (or continuous).\n    is_revolute_full_range = jnp.logical_and(is_revolute, s_max - s_min >= 2 * π)\n\n    # Clip the lower limit to -π if the joint range is larger than [-π, π].\n    s_min = jnp.where(\n        jnp.logical_and(\n            is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)\n        ),\n        -π,\n        s_min,\n    )\n\n    # Clip the upper limit to +π if the joint range is larger than [-π, π].\n    s_max = jnp.where(\n        jnp.logical_and(\n            is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)\n        ),\n        π,\n        s_max,\n    )\n\n    # Shift the lower limit if the upper limit is smaller than +π.\n    s_min = jnp.where(\n        jnp.logical_and(is_revolute_full_range, s_max < π),\n        s_max - 2 * π,\n        s_min,\n    )\n\n    # Shift the upper limit if the lower limit is larger than -π.\n    s_max = jnp.where(\n        jnp.logical_and(is_revolute_full_range, s_min > -π),\n        s_min + 2 * π,\n        s_max,\n    )\n\n    # Sample the joint positions.\n    s_random = jax.random.uniform(\n        minval=s_min,\n        maxval=s_max,\n        key=key,\n        shape=s_min.shape,\n    )\n\n    return s_random\n"
  },
  {
    "path": "src/jaxsim/api/kin_dyn_parameters.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nfrom itertools import starmap\nfrom typing import ClassVar\n\nimport jax.lax\nimport jax.numpy as jnp\nimport jax_dataclasses\nimport numpy as np\nimport numpy.typing as npt\nfrom jax_dataclasses import Static\n\nimport jaxsim\nimport jaxsim.typing as jtp\nfrom jaxsim.math import Inertia, JointModel, supported_joint_motion\nfrom jaxsim.math.adjoint import Adjoint\nfrom jaxsim.parsers.descriptions import JointDescription, JointType, ModelDescription\nfrom jaxsim.utils import HashedNumpyArray, JaxsimDataclass\n\n\n@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)\nclass KinDynParameters(JaxsimDataclass):\n    r\"\"\"\n    Class storing the kinematic and dynamic parameters of a model.\n\n    Attributes:\n        link_names: The names of the links.\n        parent_array: The parent array :math:`\\lambda(i)` of the model.\n        support_body_array_bool:\n            The boolean support parent array :math:`\\kappa_{b}(i)` of the model.\n        link_parameters: The parameters of the links.\n        frame_parameters: The parameters of the frames.\n        contact_parameters: The parameters of the collidable points.\n        joint_model: The joint model of the model.\n        joint_parameters: The parameters of the joints.\n        hw_link_metadata: The hardware parameters of the model links.\n        constraints: The kinematic constraints of the model. 
They can be used only with Relaxed-Rigid contact model.\n    \"\"\"\n\n    # Static\n    link_names: Static[tuple[str]]\n    _parent_array: Static[HashedNumpyArray]\n    _support_body_array_bool: Static[HashedNumpyArray]\n    _motion_subspaces: Static[HashedNumpyArray]\n\n    # Tree level structure for parallel algorithms.\n    # level_nodes: (n_levels, max_width) array of link indices at each depth level,\n    #   padded with 0 for levels with fewer nodes than max_width.\n    # level_mask: (n_levels, max_width) boolean mask, True for real nodes.\n    _level_nodes: Static[HashedNumpyArray]\n    _level_mask: Static[HashedNumpyArray]\n\n    # Links\n    link_parameters: LinkParameters\n\n    # Contacts\n    contact_parameters: ContactParameters\n\n    # Frames\n    frame_parameters: FrameParameters\n\n    # Joints\n    joint_model: JointModel\n    joint_parameters: JointParameters | None\n\n    # Model hardware parameters\n    hw_link_metadata: HwLinkMetadata | None = dataclasses.field(default=None)\n\n    # Kinematic constraints\n    constraints: ConstraintMap | None = dataclasses.field(default=None)\n\n    @property\n    def motion_subspaces(self) -> jtp.Matrix:\n        r\"\"\"\n        Return the motion subspaces :math:`\\mathbf{S}(s)` of the joints.\n        \"\"\"\n        return self._motion_subspaces.get()\n\n    @property\n    def parent_array(self) -> jtp.Vector:\n        r\"\"\"\n        Return the parent array :math:`\\lambda(i)` of the model.\n        \"\"\"\n        return self._parent_array.get()\n\n    @property\n    def support_body_array_bool(self) -> jtp.Matrix:\n        r\"\"\"\n        Return the boolean support parent array :math:`\\kappa_{b}(i)` of the model.\n        \"\"\"\n        return self._support_body_array_bool.get()\n\n    @property\n    def level_nodes(self) -> jtp.Matrix:\n        r\"\"\"\n        Return the tree level nodes array of shape ``(n_levels, max_width)``.\n        Each row contains the link indices at the corresponding 
depth level,\n        padded with 0 for levels with fewer nodes than ``max_width``.\n        \"\"\"\n        return self._level_nodes.get()\n\n    @property\n    def level_mask(self) -> jtp.Matrix:\n        r\"\"\"\n        Return the tree level mask of shape ``(n_levels, max_width)``.\n        Each entry is ``True`` for real nodes and ``False`` for padding.\n        \"\"\"\n        return self._level_mask.get()\n\n    @staticmethod\n    def _compute_tree_levels(\n        parent_array: np.ndarray,\n    ) -> tuple[np.ndarray, np.ndarray]:\n        \"\"\"\n        Compute the tree level decomposition from a parent array.\n\n        Args:\n            parent_array: Array of shape ``(n,)`` where ``parent_array[i]``\n                is the parent of link ``i``. ``parent_array[0] == -1`` for the root.\n\n        Returns:\n            A tuple ``(level_nodes, level_mask)`` where:\n            - ``level_nodes`` has shape ``(n_levels, max_width)`` with link\n              indices at each depth level (padded with 0).\n            - ``level_mask`` has shape ``(n_levels, max_width)`` with ``True``\n              for real nodes.\n        \"\"\"\n        import numpy as np\n\n        n = len(parent_array)\n\n        # Compute depth of each node.\n        depth = np.zeros(n, dtype=int)\n        for i in range(1, n):\n            depth[i] = depth[parent_array[i]] + 1\n\n        max_depth = int(depth.max()) if n > 0 else 0\n        n_levels = max_depth + 1\n\n        # Group nodes by depth level.\n        levels: list[list[int]] = [[] for _ in range(n_levels)]\n        for i in range(n):\n            levels[depth[i]].append(i)\n\n        max_width = max(len(lev) for lev in levels) if levels else 1\n\n        # Build padded arrays.\n        level_nodes = np.zeros((n_levels, max_width), dtype=int)\n        level_mask = np.zeros((n_levels, max_width), dtype=bool)\n        for d, lev in enumerate(levels):\n            for j, node_idx in enumerate(lev):\n                level_nodes[d, j] 
= node_idx\n                level_mask[d, j] = True\n\n        return level_nodes, level_mask\n\n    @staticmethod\n    def build(\n        model_description: ModelDescription, constraints: ConstraintMap | None\n    ) -> KinDynParameters:\n        \"\"\"\n        Construct the kinematic and dynamic parameters of the model.\n\n        Args:\n            model_description: The parsed model description to consider.\n            constraints: An object of type ConstraintMap specifying the kinematic constraint of the model.\n\n        Returns:\n            The kinematic and dynamic parameters of the model.\n\n        Note:\n            This class is meant to ease the management of parametric models in\n            an automatic differentiation context.\n        \"\"\"\n\n        # Extract the links ordered by their index.\n        # The link index corresponds to the body index ∈ [0, num_bodies - 1].\n        ordered_links = sorted(\n            list(model_description.links_dict.values()),\n            key=lambda l: l.index,\n        )\n\n        # Extract the joints ordered by their index.\n        # The joint index matches the index of its child link, therefore it starts\n        # from 1. 
Keep this in mind since this 1-indexing might introduce bugs.\n        ordered_joints = sorted(\n            list(model_description.joints_dict.values()),\n            key=lambda j: j.index,\n        )\n\n        # ================\n        # Links properties\n        # ================\n\n        # Create a list of link parameters objects.\n        link_parameters_list = [\n            LinkParameters.build_from_spatial_inertia(index=link.index, M=link.inertia)\n            for link in ordered_links\n        ]\n\n        # Create a vectorized object of link parameters.\n        link_parameters = jax.tree.map(lambda *l: jnp.stack(l), *link_parameters_list)\n\n        # =================\n        # Joints properties\n        # =================\n\n        # Create a list of joint parameters objects.\n        joint_parameters_list = [\n            JointParameters.build_from_joint_description(joint_description=joint)\n            for joint in ordered_joints\n        ]\n\n        # Create a vectorized object of joint parameters.\n        joint_parameters = (\n            jax.tree.map(lambda *l: jnp.stack(l), *joint_parameters_list)\n            if ordered_joints\n            else JointParameters(\n                index=jnp.array([], dtype=int),\n                friction_static=jnp.array([], dtype=float),\n                friction_viscous=jnp.array([], dtype=float),\n                position_limits_min=jnp.array([], dtype=float),\n                position_limits_max=jnp.array([], dtype=float),\n                position_limit_spring=jnp.array([], dtype=float),\n                position_limit_damper=jnp.array([], dtype=float),\n            )\n        )\n\n        # Create an object that defines the joint model (parent-to-child transforms).\n        joint_model = JointModel.build(description=model_description)\n\n        # ===================\n        # Contacts properties\n        # ===================\n\n        # Create the object storing the parameters of collidable 
points.\n        # Note that, contrarily to LinkParameters and JointsParameters, this object\n        # is not created with vmap. This is because the \"body\" attribute of the object\n        # must be Static for JIT-related reasons, and tree_map would not consider it\n        # as a leaf.\n        contact_parameters = ContactParameters.build_from(\n            model_description=model_description\n        )\n\n        # =================\n        # Frames properties\n        # =================\n\n        # Create the object storing the parameters of frames.\n        # Note that, contrarily to LinkParameters and JointsParameters, this object\n        # is not created with vmap. This is because the \"name\" attribute of the object\n        # must be Static for JIT-related reasons, and tree_map would not consider it\n        # as a leaf.\n        frame_parameters = FrameParameters.build_from(\n            model_description=model_description\n        )\n\n        # ===============\n        # Tree properties\n        # ===============\n\n        # Build the parent array λ(i) of the model.\n        # Note: the parent of the base link is not set since it's not defined.\n        parent_array_dict = {\n            link.index: model_description.links_dict[link.parent_name].index\n            for link in ordered_links\n            if link.parent_name is not None\n        }\n        parent_array = jnp.array([-1, *list(parent_array_dict.values())], dtype=int)\n\n        # Instead of building the support parent array κ(i) for each link of the model,\n        # that has a variable length depending on the number of links connecting the\n        # root to the i-th link, we build the corresponding boolean version.\n        # Given a link index i, the boolean support parent array κb(i) is an array\n        # with the same number of elements of λ(i) having the i-th element set to True\n        # if the i-th link is in the support parent array κ(i), False otherwise.\n        # We 
store the boolean κb(i) as static attribute of the PyTree so that\n        # algorithms that need to access it can be jit-compiled.\n        def κb(link_index: jtp.IntLike) -> jtp.Vector:\n            κb = jnp.zeros(len(ordered_links), dtype=bool)\n\n            carry0 = κb, link_index\n\n            def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:\n                κb, active_link_index = carry\n\n                κb, active_link_index = jax.lax.cond(\n                    pred=(i == active_link_index),\n                    false_fun=lambda: (κb, active_link_index),\n                    true_fun=lambda: (\n                        κb.at[active_link_index].set(True),\n                        parent_array[active_link_index],\n                    ),\n                )\n\n                return (κb, active_link_index), None\n\n            (κb, _), _ = jax.lax.scan(\n                f=scan_body,\n                init=carry0,\n                xs=jnp.flip(jnp.arange(start=0, stop=len(ordered_links))),\n            )\n\n            return κb\n\n        support_body_array_bool = jax.vmap(κb)(\n            jnp.arange(start=0, stop=len(ordered_links))\n        )\n\n        def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:\n            S = {\n                JointType.Fixed: np.zeros(shape=(6, 1)),\n                JointType.Revolute: np.vstack(np.hstack([np.zeros(3), axis.axis])),\n                JointType.Prismatic: np.vstack(np.hstack([axis.axis, np.zeros(3)])),\n            }\n\n            return S[joint_type]\n\n        S_J = jnp.array(\n            list(\n                starmap(\n                    motion_subspace,\n                    zip(\n                        joint_model.joint_types[1:], joint_model.joint_axis, strict=True\n                    ),\n                )\n            )\n            if len(joint_model.joint_axis) != 0\n            else jnp.empty((0, 6, 1))\n        )\n\n        motion_subspaces = 
jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])\n\n        # ====================\n        # Tree level structure\n        # ====================\n\n        parent_array_np = np.array([-1, *list(parent_array_dict.values())], dtype=int)\n        level_nodes, level_mask = KinDynParameters._compute_tree_levels(parent_array_np)\n\n        # ===========\n        # Constraints\n        # ===========\n\n        constraints = ConstraintMap() if constraints is None else constraints\n\n        # =================================\n        # Build and return KinDynParameters\n        # =================================\n\n        return KinDynParameters(\n            link_names=tuple(l.name for l in ordered_links),\n            _parent_array=HashedNumpyArray(array=parent_array),\n            _support_body_array_bool=HashedNumpyArray(array=support_body_array_bool),\n            _motion_subspaces=HashedNumpyArray(array=motion_subspaces),\n            _level_nodes=HashedNumpyArray(array=level_nodes),\n            _level_mask=HashedNumpyArray(array=level_mask),\n            link_parameters=link_parameters,\n            joint_model=joint_model,\n            joint_parameters=joint_parameters,\n            contact_parameters=contact_parameters,\n            frame_parameters=frame_parameters,\n            constraints=constraints,\n        )\n\n    def __eq__(self, other: KinDynParameters) -> bool:\n        if not isinstance(other, KinDynParameters):\n            return False\n\n        return hash(self) == hash(other)\n\n    def __hash__(self) -> int:\n        return hash(\n            (\n                hash(self.number_of_links()),\n                hash(self.number_of_joints()),\n                hash(self.frame_parameters.name),\n                hash(self.frame_parameters.body),\n                hash(self._parent_array),\n                hash(self._support_body_array_bool),\n            )\n        )\n\n    # =============================\n    # Helpers to extract parameters\n 
   # =============================\n\n    def number_of_links(self) -> int:\n        \"\"\"\n        Return the number of links of the model.\n\n        Returns:\n            The number of links of the model.\n        \"\"\"\n\n        return len(self.link_names)\n\n    def number_of_joints(self) -> int:\n        \"\"\"\n        Return the number of joints of the model.\n\n        Returns:\n            The number of joints of the model.\n        \"\"\"\n\n        return len(self.joint_model.joint_names) - 1\n\n    def number_of_frames(self) -> int:\n        \"\"\"\n        Return the number of frames of the model.\n\n        Returns:\n            The number of frames of the model.\n        \"\"\"\n\n        return len(self.frame_parameters.name)\n\n    def support_body_array(self, link_index: jtp.IntLike) -> jtp.Vector:\n        r\"\"\"\n        Return the support parent array :math:`\\kappa(i)` of a link.\n\n        Args:\n            link_index: The index of the link.\n\n        Returns:\n            The support parent array :math:`\\kappa(i)` of the link.\n\n        Note:\n            This method returns a variable-length vector. 
In jit-compiled functions,\n            it's better to use the (static) boolean version `support_body_array_bool`.\n        \"\"\"\n\n        return jnp.array(\n            jnp.where(self.support_body_array_bool[link_index])[0], dtype=int\n        )\n\n    # ========================\n    # Quantities used by RBDAs\n    # ========================\n\n    @jax.jit\n    def links_spatial_inertia(self) -> jtp.Array:\n        \"\"\"\n        Return the spatial inertia of all links of the model.\n\n        Returns:\n            The spatial inertia of all links of the model.\n        \"\"\"\n\n        return jax.vmap(LinkParameters.spatial_inertia)(self.link_parameters)\n\n    @jax.jit\n    def tree_transforms(self) -> jtp.Array:\n        r\"\"\"\n        Return the tree transforms of the model.\n\n        Returns:\n            The transforms\n            :math:`{}^{\\text{pre}(i)} H_{\\lambda(i)}`\n            of all joints of the model.\n        \"\"\"\n\n        pre_Xi_λ = jax.vmap(\n            lambda i: self.joint_model.parent_H_predecessor(joint_index=i)\n            .inverse()\n            .adjoint()\n        )(jnp.arange(1, self.number_of_joints() + 1))\n\n        return jnp.vstack(\n            [\n                jnp.zeros(shape=(1, 6, 6), dtype=float),\n                pre_Xi_λ,\n            ]\n        )\n\n    @jax.jit\n    def joint_transforms(\n        self, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike\n    ) -> jtp.Array:\n        r\"\"\"\n        Return the transforms of the joints.\n\n        Args:\n            joint_positions: The joint positions.\n            base_transform: The homogeneous matrix defining the base pose.\n\n        Returns:\n            The stacked transforms\n            :math:`{}^{i} \\mathbf{H}_{\\lambda(i)}(s)`\n            of each joint.\n        \"\"\"\n\n        # Rename the base transform.\n        W_H_B = base_transform\n\n        # Extract the parent-to-predecessor fixed transforms of the joints.\n        
λ_H_pre = jnp.vstack(\n            [\n                jnp.eye(4)[jnp.newaxis],\n                self.joint_model.λ_H_pre[1 : 1 + self.number_of_joints()],\n            ]\n        )\n        if self.number_of_joints() == 0:\n            pre_H_suc_J = jnp.empty((0, 4, 4))\n        else:\n            pre_H_suc_J = jax.vmap(supported_joint_motion)(\n                joint_types=jnp.array(self.joint_model.joint_types[1:]).astype(int),\n                joint_positions=jnp.array(joint_positions),\n                joint_axes=jnp.array([j.axis for j in self.joint_model.joint_axis]),\n            )\n\n        # Extract the transforms and motion subspaces of the joints.\n        # We stack the base transform W_H_B at index 0, and a dummy motion subspace\n        # for either the fixed or free-floating joint connecting the world to the base.\n        pre_H_suc = jnp.vstack([W_H_B[jnp.newaxis, ...], pre_H_suc_J])\n\n        # Extract the successor-to-child fixed transforms.\n        # Note that here we include also the index 0 since suc_H_child[0] stores the\n        # optional pose of the base link w.r.t. 
the root frame of the model.\n        # This is supported by SDF when the base link <pose> element is defined.\n        suc_H_i = self.joint_model.suc_H_i[jnp.arange(0, 1 + self.number_of_joints())]\n\n        # Compute the overall transforms from the parent to the child of each joint by\n        # composing all the components of our joint model.\n        i_X_λ = jax.vmap(\n            lambda λ_Hi_pre, pre_Hi_suc, suc_Hi_i: Adjoint.from_transform(\n                transform=λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i, inverse=True\n            )\n        )(λ_H_pre, pre_H_suc, suc_H_i)\n\n        return i_X_λ\n\n    # ============================\n    # Helpers to update parameters\n    # ============================\n\n    def set_link_mass(\n        self, link_index: jtp.IntLike, mass: jtp.FloatLike\n    ) -> KinDynParameters:\n        \"\"\"\n        Set the mass of a link.\n\n        Args:\n            link_index: The index of the link.\n            mass: The mass of the link.\n\n        Returns:\n            The updated kinematic and dynamic parameters of the model.\n        \"\"\"\n\n        link_parameters = self.link_parameters.replace(\n            mass=self.link_parameters.mass.at[link_index].set(mass)\n        )\n\n        return self.replace(link_parameters=link_parameters)\n\n    def set_link_inertia(\n        self, link_index: jtp.IntLike, inertia: jtp.MatrixLike\n    ) -> KinDynParameters:\n        r\"\"\"\n        Set the inertia tensor of a link.\n\n        Args:\n            link_index: The index of the link.\n            inertia: The :math:`3 \\times 3` inertia tensor of the link.\n\n        Returns:\n            The updated kinematic and dynamic parameters of the model.\n        \"\"\"\n\n        inertia_elements = LinkParameters.flatten_inertia_tensor(I=inertia)\n\n        link_parameters = self.link_parameters.replace(\n            inertia_elements=self.link_parameters.inertia_elements.at[link_index].set(\n                inertia_elements\n            )\n 
       )\n\n        return self.replace(link_parameters=link_parameters)\n\n\n@jax_dataclasses.pytree_dataclass\nclass JointParameters(JaxsimDataclass):\n    \"\"\"\n    Class storing the parameters of a joint.\n\n    Attributes:\n        index: The index of the joint.\n        friction_static: The static friction of the joint.\n        friction_viscous: The viscous friction of the joint.\n        position_limits_min: The lower position limit of the joint.\n        position_limits_max: The upper position limit of the joint.\n        position_limit_spring: The spring constant of the position limit.\n        position_limit_damper: The damper constant of the position limit.\n\n    Note:\n        This class is used inside KinDynParameters to store the vectorized set\n        of joint parameters.\n    \"\"\"\n\n    index: jtp.Int\n\n    friction_static: jtp.Float\n    friction_viscous: jtp.Float\n\n    position_limits_min: jtp.Float\n    position_limits_max: jtp.Float\n\n    position_limit_spring: jtp.Float\n    position_limit_damper: jtp.Float\n\n    @staticmethod\n    def build_from_joint_description(\n        joint_description: JointDescription,\n    ) -> JointParameters:\n        \"\"\"\n        Build a JointParameters object from a joint description.\n\n        Args:\n            joint_description: The joint description to consider.\n\n        Returns:\n            The JointParameters object.\n        \"\"\"\n\n        s_min = joint_description.position_limit[0]\n        s_max = joint_description.position_limit[1]\n\n        position_limits_min = jnp.minimum(s_min, s_max)\n        position_limits_max = jnp.maximum(s_min, s_max)\n\n        friction_static = jnp.array(joint_description.friction_static).squeeze()\n        friction_viscous = jnp.array(joint_description.friction_viscous).squeeze()\n\n        position_limit_spring = jnp.array(\n            joint_description.position_limit_spring\n        ).squeeze()\n\n        position_limit_damper = jnp.array(\n         
   joint_description.position_limit_damper\n        ).squeeze()\n\n        return JointParameters(\n            index=jnp.array(joint_description.index).squeeze().astype(int),\n            friction_static=friction_static.astype(float),\n            friction_viscous=friction_viscous.astype(float),\n            position_limits_min=position_limits_min.astype(float),\n            position_limits_max=position_limits_max.astype(float),\n            position_limit_spring=position_limit_spring.astype(float),\n            position_limit_damper=position_limit_damper.astype(float),\n        )\n\n\n@jax_dataclasses.pytree_dataclass\nclass LinkParameters(JaxsimDataclass):\n    r\"\"\"\n    Class storing the parameters of a link.\n\n    Attributes:\n        index: The index of the link.\n        mass: The mass of the link.\n        inertia_elements:\n            The unique elements of the :math:`3 \\times 3` inertia tensor of the link.\n        center_of_mass:\n            The translation :math:`{}^L \\mathbf{p}_{\\text{CoM}}` between the origin\n            of the link frame and the link's center of mass, expressed in the\n            coordinates of the link frame.\n\n    Note:\n        This class is used inside KinDynParameters to store the vectorized set\n        of link parameters.\n    \"\"\"\n\n    index: jtp.Int\n\n    mass: jtp.Float\n    center_of_mass: jtp.Vector\n    inertia_elements: jtp.Vector\n\n    @staticmethod\n    def build_from_spatial_inertia(index: jtp.IntLike, M: jtp.Matrix) -> LinkParameters:\n        r\"\"\"\n        Build a LinkParameters object from a :math:`6 \\times 6` spatial inertia matrix.\n\n        Args:\n            index: The index of the link.\n            M: The :math:`6 \\times 6` spatial inertia matrix of the link.\n\n        Returns:\n            The LinkParameters object.\n        \"\"\"\n\n        # Extract the link parameters from the 6D spatial inertia.\n        m, L_p_CoM, I_CoM = Inertia.to_params(M=M)\n\n        # Extract only the 
necessary elements of the inertia tensor.\n        inertia_elements = I_CoM[jnp.triu_indices(3)]\n\n        return LinkParameters(\n            index=jnp.array(index).squeeze().astype(int),\n            mass=jnp.array(m).squeeze().astype(float),\n            center_of_mass=jnp.atleast_1d(jnp.array(L_p_CoM).squeeze()).astype(float),\n            inertia_elements=jnp.atleast_1d(inertia_elements.squeeze()).astype(float),\n        )\n\n    @staticmethod\n    def build_from_inertial_parameters(\n        index: jtp.IntLike, m: jtp.FloatLike, I: jtp.MatrixLike, c: jtp.VectorLike\n    ) -> LinkParameters:\n        r\"\"\"\n        Build a LinkParameters object from the inertial parameters of a link.\n\n        Args:\n            index: The index of the link.\n            m: The mass of the link.\n            I: The :math:`3 \\times 3` inertia tensor of the link.\n            c: The translation between the link frame and the link's center of mass.\n\n        Returns:\n            The LinkParameters object.\n        \"\"\"\n\n        # Extract only the necessary elements of the inertia tensor.\n        inertia_elements = I[jnp.triu_indices(3)]\n\n        return LinkParameters(\n            index=jnp.array(index).squeeze().astype(int),\n            mass=jnp.array(m).squeeze().astype(float),\n            center_of_mass=jnp.atleast_1d(c.squeeze()).astype(float),\n            inertia_elements=jnp.atleast_1d(inertia_elements.squeeze()).astype(float),\n        )\n\n    @staticmethod\n    def build_from_flat_parameters(\n        index: jtp.IntLike, parameters: jtp.VectorLike\n    ) -> LinkParameters:\n        \"\"\"\n        Build a LinkParameters object from a flat vector of parameters.\n\n        Args:\n            index: The index of the link.\n            parameters: The flat vector of parameters.\n\n        Returns:\n            The LinkParameters object.\n        \"\"\"\n        index = jnp.array(index).squeeze().astype(int)\n\n        m = 
jnp.array(parameters[0]).squeeze().astype(float)\n        c = jnp.atleast_1d(parameters[1:4].squeeze()).astype(float)\n        inertia_elements = jnp.atleast_1d(parameters[4:].squeeze()).astype(float)\n\n        return LinkParameters(\n            index=index, mass=m, inertia_elements=inertia_elements, center_of_mass=c\n        )\n\n    @staticmethod\n    def flat_parameters(params: LinkParameters) -> jtp.Vector:\n        \"\"\"\n        Return the parameters of a link as a flat vector.\n\n        Args:\n            params: The link parameters.\n\n        Returns:\n            The parameters of the link as a flat vector.\n        \"\"\"\n\n        return (\n            jnp.hstack(\n                [\n                    params.mass,\n                    params.center_of_mass.squeeze(),\n                    params.inertia_elements,\n                ]\n            )\n            .squeeze()\n            .astype(float)\n        )\n\n    @staticmethod\n    def inertia_tensor(params: LinkParameters) -> jtp.Matrix:\n        r\"\"\"\n        Return the :math:`3 \\times 3` inertia tensor of a link.\n\n        Args:\n            params: The link parameters.\n\n        Returns:\n            The :math:`3 \\times 3` inertia tensor of the link.\n        \"\"\"\n\n        return LinkParameters.unflatten_inertia_tensor(\n            inertia_elements=params.inertia_elements\n        )\n\n    @staticmethod\n    def spatial_inertia(params: LinkParameters) -> jtp.Matrix:\n        r\"\"\"\n        Return the :math:`6 \\times 6` spatial inertia matrix of a link.\n\n        Args:\n            params: The link parameters.\n\n        Returns:\n            The :math:`6 \\times 6` spatial inertia matrix of the link.\n        \"\"\"\n\n        return Inertia.to_sixd(\n            mass=params.mass,\n            I=LinkParameters.inertia_tensor(params),\n            com=params.center_of_mass,\n        )\n\n    @staticmethod\n    def flatten_inertia_tensor(I: jtp.Matrix) -> jtp.Vector:\n        
r\"\"\"\n        Flatten a :math:`3 \\times 3` inertia tensor into a vector of unique elements.\n\n        Args:\n            I: The :math:`3 \\times 3` inertia tensor.\n\n        Returns:\n            The vector of unique elements of the inertia tensor.\n        \"\"\"\n\n        return jnp.atleast_1d(I[jnp.triu_indices(3)].squeeze())\n\n    @staticmethod\n    def unflatten_inertia_tensor(inertia_elements: jtp.Vector) -> jtp.Matrix:\n        r\"\"\"\n        Unflatten a vector of unique elements into a :math:`3 \\times 3` inertia tensor.\n\n        Args:\n            inertia_elements: The vector of unique elements of the inertia tensor.\n\n        Returns:\n            The :math:`3 \\times 3` inertia tensor.\n        \"\"\"\n\n        I = jnp.zeros([3, 3]).at[jnp.triu_indices(3)].set(inertia_elements.squeeze())\n        return jnp.atleast_2d(jnp.where(I, I, I.T)).astype(float)\n\n\n@jax_dataclasses.pytree_dataclass\nclass ContactParameters(JaxsimDataclass):\n    \"\"\"\n    Class storing the contact parameters of a model.\n\n    Attributes:\n        body:\n            A tuple of integers representing, for each collidable point, the index of\n            the body (link) to which it is rigidly attached to.\n        point:\n            The translations between the link frame and the collidable point, expressed\n            in the coordinates of the parent link frame.\n        enabled:\n            A tuple of booleans representing, for each collidable point, whether it is\n            enabled or not in contact models.\n\n    Note:\n        Contrarily to LinkParameters and JointParameters, this class is not meant\n        to be created with vmap. 
This is because the `body` attribute must be `Static`.\n    \"\"\"\n\n    body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple)\n\n    point: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([]))\n\n    enabled: Static[tuple[bool, ...]] = dataclasses.field(default_factory=tuple)\n\n    @property\n    def indices_of_enabled_collidable_points(self) -> npt.NDArray:\n        \"\"\"\n        Return the indices of the enabled collidable points.\n        \"\"\"\n        return np.where(np.array(self.enabled))[0]\n\n    @staticmethod\n    def build_from(model_description: ModelDescription) -> ContactParameters:\n        \"\"\"\n        Build a ContactParameters object from a model description.\n\n        Args:\n            model_description: The model description to consider.\n\n        Returns:\n            The ContactParameters object.\n        \"\"\"\n\n        if len(model_description.collision_shapes) == 0:\n            return ContactParameters()\n\n        # Get all the links so that we can take their updated index.\n        links_dict = {link.name: link for link in model_description}\n\n        # Get all the enabled collidable points of the model.\n        collidable_points = model_description.all_enabled_collidable_points()\n\n        # Extract the positions L_p_C of the collidable points w.r.t. 
the link frames\n        # they are rigidly attached to.\n        points = jnp.vstack([cp.position for cp in collidable_points])\n\n        # Extract the indices of the links to which the collidable points are rigidly\n        # attached to.\n        link_index_of_points = tuple(\n            links_dict[cp.parent_link.name].index for cp in collidable_points\n        )\n\n        # Build the ContactParameters object.\n        cp = ContactParameters(\n            point=points,\n            body=link_index_of_points,\n            enabled=tuple(True for _ in link_index_of_points),\n        )\n\n        assert cp.point.shape[1] == 3, cp.point.shape[1]\n        assert cp.point.shape[0] == len(cp.body), cp.point.shape[0]\n\n        return cp\n\n\n@jax_dataclasses.pytree_dataclass\nclass FrameParameters(JaxsimDataclass):\n    \"\"\"\n    Class storing the frame parameters of a model.\n\n    Attributes:\n        name: A tuple of strings defining the frame names.\n        body:\n            A vector of integers representing, for each frame, the index of\n            the body (link) to which it is rigidly attached to.\n        transform: The transforms of the frames w.r.t. their parent link.\n\n    Note:\n        Contrarily to LinkParameters and JointParameters, this class is not meant\n        to be created with vmap. 
This is because the `name` attribute must be `Static`.\n    \"\"\"\n\n    name: Static[tuple[str, ...]] = dataclasses.field(default_factory=tuple)\n\n    body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple)\n\n    transform: jtp.Array = dataclasses.field(default_factory=lambda: jnp.array([]))\n\n    @staticmethod\n    def build_from(model_description: ModelDescription) -> FrameParameters:\n        \"\"\"\n        Build a FrameParameters object from a model description.\n\n        Args:\n            model_description: The model description to consider.\n\n        Returns:\n            The FrameParameters object.\n        \"\"\"\n\n        if len(model_description.frames) == 0:\n            return FrameParameters()\n\n        # Extract the frame names.\n        names = tuple(frame.name for frame in model_description.frames)\n\n        # For each frame, extract the index of the link to which it is attached to.\n        parent_link_index_of_frames = tuple(\n            model_description.links_dict[frame.parent_name].index\n            for frame in model_description.frames\n        )\n\n        # For each frame, extract the transform w.r.t. 
its parent link.\n        transforms = jnp.atleast_3d(\n            jnp.stack([frame.pose for frame in model_description.frames])\n        )\n\n        # Build the FrameParameters object.\n        fp = FrameParameters(\n            name=names,\n            transform=transforms.astype(float),\n            body=parent_link_index_of_frames,\n        )\n\n        assert fp.transform.shape[1:] == (4, 4), fp.transform.shape[1:]\n        assert fp.transform.shape[0] == len(fp.body), fp.transform.shape[0]\n\n        return fp\n\n\n@dataclasses.dataclass(frozen=True)\nclass LinkParametrizableShape:\n    \"\"\"\n    Enum-like class listing the supported shapes for HW parametrization.\n    \"\"\"\n\n    Unsupported: ClassVar[int] = -1\n    Box: ClassVar[int] = 0\n    Cylinder: ClassVar[int] = 1\n    Sphere: ClassVar[int] = 2\n    Mesh: ClassVar[int] = 3\n\n\n@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)\nclass HwLinkMetadata(JaxsimDataclass):\n    \"\"\"\n    Class storing the hardware parameters of a link.\n\n    Attributes:\n        link_shape: The shape of the link.\n            0 = box, 1 = cylinder, 2 = sphere, 3 = mesh, -1 = unsupported.\n        geometry: Shape parameters used by HW parametrization.\n            box: [lx,ly,lz], cylinder: [r,l,0], sphere: [r,0,0],\n            mesh: cumulative anisotropic scale factors [sx,sy,sz] (initialized to [1,1,1]).\n        density: The density of the link.\n        L_H_G: The homogeneous transformation matrix from the link frame to the CoM frame G.\n        L_H_vis: The homogeneous transformation matrix from the link frame to the visual frame.\n        L_H_pre_mask: The mask indicating the link's child joint indices.\n        L_H_pre: The homogeneous transforms for child joints.\n        mesh_moments: Precomputed volumetric moments for mesh shapes (n_links x 13).\n            Each row stores [V_ref, com_x, com_y, com_z, Σ_00..Σ_22] where V_ref is the\n            reference volume, com is the volumetric center 
of mass, and Σ is the\n            volumetric covariance matrix at the origin. Zero for non-mesh links.\n        mesh_vertices: The original centered mesh vertices (Nx3) for mesh shapes, None otherwise.\n        mesh_faces: The mesh triangle faces (Mx3 integer indices) for mesh shapes, None otherwise.\n        mesh_offset: The original mesh centroid offset (3D vector) for mesh shapes, None otherwise.\n        mesh_uri: The path to the mesh file for reference, None otherwise.\n    \"\"\"\n\n    link_shape: jtp.Vector\n    geometry: jtp.Vector\n    density: jtp.Float\n    L_H_G: jtp.Matrix\n    L_H_vis: jtp.Matrix\n    L_H_pre_mask: jtp.Vector\n    L_H_pre: jtp.Matrix\n    mesh_moments: jtp.Matrix\n    mesh_vertices: Static[tuple[HashedNumpyArray | None, ...] | None]\n    mesh_faces: Static[tuple[HashedNumpyArray | None, ...] | None]\n    mesh_offset: Static[tuple[HashedNumpyArray | None, ...] | None]\n    mesh_uri: Static[tuple[str | None, ...] | None]\n\n    @classmethod\n    def empty(cls) -> HwLinkMetadata:\n        \"\"\"Return hardware metadata representing the absence of links.\"\"\"\n        return cls(\n            link_shape=jnp.array([], dtype=int),\n            geometry=jnp.array([], dtype=float),\n            density=jnp.array([], dtype=float),\n            L_H_G=jnp.array([], dtype=float),\n            L_H_vis=jnp.array([], dtype=float),\n            L_H_pre_mask=jnp.array([], dtype=bool),\n            L_H_pre=jnp.array([], dtype=float),\n            mesh_moments=jnp.zeros((0, 13), dtype=float),\n            mesh_vertices=None,\n            mesh_faces=None,\n            mesh_offset=None,\n            mesh_uri=None,\n        )\n\n    @staticmethod\n    def compute_mesh_inertia(\n        vertices: jtp.Matrix, faces: jtp.Matrix, density: jtp.Float\n    ) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]:\n        \"\"\"\n        Compute mass, center of mass, and inertia tensor from mesh geometry.\n\n        Uses the divergence theorem to compute volumetric 
properties by integrating\n        over tetrahedra formed between the mesh surface and the origin.\n\n        Args:\n            vertices: Mesh vertices (Nx3) in the link frame, should be centered.\n            faces: Triangle face indices (Mx3), integer indices into vertices array.\n            density: Material density.\n\n        Returns:\n            A tuple containing the computed mass, the CoM position and the 3x3\n            inertia tensor at the CoM.\n        \"\"\"\n\n        # Extract triangles from vertices using face indices\n        triangles = vertices[faces.astype(int)]\n        A, B, C = triangles[:, 0], triangles[:, 1], triangles[:, 2]\n\n        # Compute signed volume of tetrahedra relative to origin\n        # vol = 1/6 * (A . (B x C))\n        tetrahedron_volumes = jnp.sum(A * jnp.cross(B, C), axis=1) / 6.0\n\n        total_signed_volume = jnp.sum(tetrahedron_volumes)\n\n        # Normalize the global winding sign so positive density yields non-negative mass.\n        orientation_sign = jnp.where(total_signed_volume < 0, -1.0, 1.0)\n        tetrahedron_volumes = tetrahedron_volumes * orientation_sign\n        total_volume = jnp.sum(tetrahedron_volumes)\n\n        eps = jnp.asarray(1e-12, dtype=total_volume.dtype)\n        is_valid_volume = jnp.abs(total_volume) > eps\n        safe_total_volume = jnp.where(is_valid_volume, total_volume, 1.0)\n        mass = jnp.where(is_valid_volume, total_volume * density, 0.0)\n\n        # Compute center of mass\n        tet_coms = (A + B + C) / 4.0\n        com_position = jnp.where(\n            is_valid_volume,\n            jnp.sum(tet_coms * tetrahedron_volumes[:, None], axis=0)\n            / safe_total_volume,\n            jnp.zeros(3, dtype=vertices.dtype),\n        )\n\n        # Compute inertia tensor with covariance approach\n        def compute_tetrahedron_covariance(a, b, c, vol):\n            s = a + b + c\n            return (vol / 20.0) * (\n                jnp.outer(a, a) + jnp.outer(b, b) + 
jnp.outer(c, c) + jnp.outer(s, s)\n            )\n\n        covariance_matrices = jax.vmap(compute_tetrahedron_covariance)(\n            A, B, C, tetrahedron_volumes\n        )\n        Σ_origin = jnp.sum(covariance_matrices, axis=0)\n\n        # Shift to CoM using parallel axis theorem\n        Σ_com = Σ_origin * density - mass * jnp.outer(com_position, com_position)\n\n        # Convert covariance to inertia tensor\n        I_com = jnp.trace(Σ_com) * jnp.eye(3, dtype=vertices.dtype) - Σ_com\n        I_com = jnp.where(\n            is_valid_volume, I_com, jnp.zeros((3, 3), dtype=vertices.dtype)\n        )\n\n        return mass, com_position, I_com\n\n    @staticmethod\n    def precompute_mesh_moments(vertices: np.ndarray, faces: np.ndarray) -> np.ndarray:\n        \"\"\"\n        Precompute volumetric moments from reference mesh geometry.\n\n        Computes the reference volume, center of mass, and volumetric covariance\n        matrix at the origin using numpy. These 13 scalars are sufficient to\n        analytically reconstruct mass and inertia under any anisotropic scaling,\n        avoiding the need to embed full mesh arrays in JIT-compiled programs.\n\n        Args:\n            vertices: Mesh vertices (Nx3), should be centered.\n            faces: Triangle face indices (Mx3).\n\n        Returns:\n            A 13-element array: [V_ref, com_x, com_y, com_z, Σ_00..Σ_22].\n        \"\"\"\n\n        triangles = vertices[faces.astype(int)]\n        A, B, C = triangles[:, 0], triangles[:, 1], triangles[:, 2]\n\n        volumes = np.sum(A * np.cross(B, C), axis=1) / 6.0\n\n        total_signed = np.sum(volumes)\n        sign = np.sign(total_signed) if abs(total_signed) > 1e-12 else 1.0\n        volumes = volumes * sign\n        V_ref = np.sum(volumes)\n\n        if abs(V_ref) < 1e-12:\n            return np.zeros(13, dtype=np.float64)\n\n        # Center of mass\n        com = np.sum(volumes[:, None] * (A + B + C) / 4.0, axis=0) / V_ref\n\n        # Volumetric 
covariance at origin (same formula as compute_mesh_inertia)\n        S = A + B + C\n        cov = (volumes[:, None, None] / 20.0) * (\n            A[:, :, None] * A[:, None, :]\n            + B[:, :, None] * B[:, None, :]\n            + C[:, :, None] * C[:, None, :]\n            + S[:, :, None] * S[:, None, :]\n        )\n        Sigma = np.sum(cov, axis=0)\n\n        return np.concatenate([[V_ref], com, Sigma.flatten()])\n\n    @staticmethod\n    def compute_mesh_inertia_from_moments(\n        moments: jtp.Vector, dims: jtp.Vector, density: jtp.Float\n    ) -> tuple[jtp.Float, jtp.Matrix]:\n        \"\"\"\n        Compute mass and inertia tensor from precomputed volumetric moments.\n\n        Uses analytical scaling laws to derive physical properties under\n        anisotropic scaling without requiring the full mesh geometry.\n\n        Under scaling S = diag(sx, sy, sz):\n          - V' = det(S) * V_ref\n          - com' = S @ com_ref\n          - Σ_origin' = det(S) * S @ Σ_ref @ S\n\n        Args:\n            moments: Precomputed moments array of length 13.\n            dims: Current anisotropic scale factors [sx, sy, sz].\n            density: Current material density.\n\n        Returns:\n            A tuple of (mass, inertia_at_com).\n        \"\"\"\n\n        V_ref = moments[0]\n        com_ref = moments[1:4]\n        Sigma_ref = moments[4:13].reshape(3, 3)\n\n        det_s = dims[0] * dims[1] * dims[2]\n        S = jnp.diag(dims)\n\n        mass = density * V_ref * det_s\n        com = dims * com_ref\n\n        Sigma_scaled = det_s * (S @ Sigma_ref @ S)\n        Sigma_com = density * Sigma_scaled - mass * jnp.outer(com, com)\n        I_com = jnp.trace(Sigma_com) * jnp.eye(3) - Sigma_com\n\n        is_valid = V_ref > 1e-12\n        mass = jnp.where(is_valid, mass, 0.0)\n        I_com = jnp.where(is_valid, I_com, jnp.zeros((3, 3)))\n\n        return mass, I_com\n\n    @staticmethod\n    def compute_mass_and_inertia(\n        hw_link_metadata: 
HwLinkMetadata,\n    ) -> tuple[jtp.Float, jtp.Matrix]:\n        \"\"\"\n        Compute the mass and inertia of a hardware link based on its metadata.\n\n        This function calculates the mass and inertia tensor of a hardware link\n        using its shape, dimensions, and density. The computation is performed\n        by using shape-specific methods.\n\n        Args:\n            hw_link_metadata: Metadata describing the hardware link,\n                including its shape, dimensions, and density.\n\n        Returns:\n            tuple: A tuple containing:\n                - mass: The computed mass of the hardware link.\n                - inertia: The computed inertia tensor of the hardware link.\n        \"\"\"\n\n        def box(dims, density, _moments) -> tuple[jtp.Float, jtp.Matrix]:\n            lx, ly, lz = dims\n\n            mass = density * lx * ly * lz\n\n            inertia = jnp.array(\n                [\n                    [mass * (ly**2 + lz**2) / 12, 0, 0],\n                    [0, mass * (lx**2 + lz**2) / 12, 0],\n                    [0, 0, mass * (lx**2 + ly**2) / 12],\n                ]\n            )\n            return mass, inertia\n\n        def cylinder(dims, density, _moments) -> tuple[jtp.Float, jtp.Matrix]:\n            r, l, _ = dims\n\n            mass = density * (jnp.pi * r**2 * l)\n\n            inertia = jnp.array(\n                [\n                    [mass * (3 * r**2 + l**2) / 12, 0, 0],\n                    [0, mass * (3 * r**2 + l**2) / 12, 0],\n                    [0, 0, mass * (r**2) / 2],\n                ]\n            )\n\n            return mass, inertia\n\n        def sphere(dims, density, _moments) -> tuple[jtp.Float, jtp.Matrix]:\n            r = dims[0]\n\n            mass = density * (4 / 3 * jnp.pi * r**3)\n\n            inertia = jnp.eye(3) * (2 / 5 * mass * r**2)\n\n            return mass, inertia\n\n        def mesh(dims, density, moments) -> tuple[jtp.Float, jtp.Matrix]:\n            return 
HwLinkMetadata.compute_mesh_inertia_from_moments(\n                moments, dims, density\n            )\n\n        def compute_mass_inertia(shape_idx, dims, density, moments):\n            def unsupported_case(_):\n                return (\n                    jnp.asarray(0.0, dtype=density.dtype),\n                    jnp.zeros((3, 3), dtype=density.dtype),\n                )\n\n            def supported_case(idx):\n                return jax.lax.switch(\n                    idx, (box, cylinder, sphere, mesh), dims, density, moments\n                )\n\n            return jax.lax.cond(\n                shape_idx < 0, unsupported_case, supported_case, shape_idx\n            )\n\n        masses, inertias = jax.vmap(compute_mass_inertia)(\n            hw_link_metadata.link_shape,\n            hw_link_metadata.geometry,\n            hw_link_metadata.density,\n            hw_link_metadata.mesh_moments,\n        )\n\n        return masses, inertias\n\n    @staticmethod\n    def _convert_scaling_to_3d_vector(\n        link_shapes: jtp.Int, scaling_factors: jtp.Vector\n    ) -> jtp.Vector:\n        \"\"\"\n        Convert scaling factors for specific shape dimensions into a 3D scaling vector.\n\n        Args:\n            link_shapes: The link_shapes of the link (e.g., box, sphere, cylinder, mesh).\n            scaling_factors: The scaling factors for the shape dimensions.\n\n        Returns:\n            A 3D scaling vector to apply to position vectors.\n\n        Note:\n            The scaling factors are applied as follows to generate the 3D scale vector:\n            - Box: [lx, ly, lz]\n            - Cylinder: [r, r, l]\n            - Sphere: [r, r, r]\n            - Mesh: [sx, sy, sz]\n        \"\"\"\n\n        # Index mapping for each shape type (link_shapes x 3 dims)\n        # Box: [lx, ly, lz] -> [0, 1, 2]\n        # Cylinder: [r, r, l] -> [0, 0, 1]\n        # Sphere: [r, r, r] -> [0, 0, 0]\n        # Mesh: [sx, sy, sz] -> [0, 1, 2]\n        shape_indices = 
jnp.array(\n            [\n                [0, 1, 2],  # Box\n                [0, 0, 1],  # Cylinder\n                [0, 0, 0],  # Sphere\n                [0, 1, 2],  # Mesh\n            ]\n        )\n\n        # For each link, get the index vector for its shape\n        per_link_indices = shape_indices[link_shapes]\n\n        # Gather dims per link according to per_link_indices\n        return scaling_factors.dims[per_link_indices.squeeze()]\n\n    @staticmethod\n    def compute_contact_points(\n        original_contact_params: jtp.Vector,\n        link_shapes: jtp.Vector,\n        original_com_positions: jtp.Vector,\n        updated_com_positions: jtp.Vector,\n        scaling_factors: ScalingFactors,\n    ) -> jtp.Matrix:\n        \"\"\"\n        Compute the new contact points based on the original contact parameters and\n        the scaling factors.\n\n        Args:\n            original_contact_params: The original contact parameters.\n            link_shapes: The shape types of the links (e.g., box, sphere, cylinder).\n            original_com_positions: The original center of mass positions of the links.\n            updated_com_positions: The updated center of mass positions of the links.\n            scaling_factors: The scaling factors for the link dimensions.\n\n        Returns:\n            The new contact points positions in the parent link frame.\n        \"\"\"\n\n        parent_link_indices = np.array(original_contact_params.body)\n\n        # Translate the original contact point positions in the origin, so\n        # that we can apply the scaling factors.\n        L_p_Ci = (\n            original_contact_params.point - original_com_positions[parent_link_indices]\n        )\n\n        # Extract the shape types of the parent links.\n        parent_link_shapes = jnp.array(link_shapes[parent_link_indices])\n\n        def sphere(parent_idx, L_p_C):\n            r = scaling_factors.dims[parent_idx][0]\n            return L_p_C * r\n\n        def 
cylinder(parent_idx, L_p_C):\n            # TODO: Cylinder collisions are not currently supported in JaxSim.\n            return L_p_C\n\n        def box(parent_idx, L_p_C):\n            lx, ly, lz = scaling_factors.dims[parent_idx]\n            return jnp.hstack(\n                [\n                    L_p_C[0] * lx,\n                    L_p_C[1] * ly,\n                    L_p_C[2] * lz,\n                ]\n            )\n\n        def mesh(parent_idx, L_p_C):\n            sx, sy, sz = scaling_factors.dims[parent_idx]\n            return jnp.hstack(\n                [\n                    L_p_C[0] * sx,\n                    L_p_C[1] * sy,\n                    L_p_C[2] * sz,\n                ]\n            )\n\n        new_positions = jax.vmap(\n            lambda shape_idx, parent_idx, L_p_C: jax.lax.switch(\n                shape_idx, (box, cylinder, sphere, mesh), parent_idx, L_p_C\n            )\n        )(\n            parent_link_shapes,\n            parent_link_indices,\n            L_p_Ci,\n        )\n\n        return new_positions + updated_com_positions[parent_link_indices]\n\n    @staticmethod\n    def compute_inertia_link(I_com, L_H_G) -> jtp.Matrix:\n        \"\"\"\n        Rotate the inertia tensor expressed in the CoM frame G into the orientation of the link frame.\n        \"\"\"\n\n        L_R_G = L_H_G[:3, :3]\n        return L_R_G @ I_com @ L_R_G.T\n\n    @staticmethod\n    def apply_scaling(\n        has_joints: bool,\n        hw_metadata: HwLinkMetadata,\n        scaling_factors: ScalingFactors,\n    ) -> HwLinkMetadata:\n        \"\"\"\n        Apply scaling to the hardware parameters and return a new HwLinkMetadata object.\n\n        Args:\n            has_joints: Whether the model has at least one joint.\n            hw_metadata: The original HwLinkMetadata object.\n            scaling_factors: The scaling factors to apply.\n\n        Returns:\n            A new HwLinkMetadata object with updated 
parameters.\n        \"\"\"\n\n        scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector(\n            hw_metadata.link_shape, scaling_factors\n        )\n\n        # =================================\n        # Update the kinematics of the link\n        # =================================\n\n        # Get the nominal transforms\n        L_H_G = hw_metadata.L_H_G\n        L_H_vis = hw_metadata.L_H_vis\n        L_H_pre_array = hw_metadata.L_H_pre\n        L_H_pre_mask = hw_metadata.L_H_pre_mask\n\n        # Express the transforms in the G frame\n        G_H_L = jaxsim.math.Transform.inverse(L_H_G)\n        G_H_vis = G_H_L @ L_H_vis\n\n        G_H_pre_array = (\n            jax.vmap(lambda L_H_pre: G_H_L @ L_H_pre)(L_H_pre_array)\n            if has_joints\n            else L_H_pre_array\n        )\n\n        # Apply the scaling to the position vectors\n        G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3])\n\n        # Apply scaling to the position vectors in G_H_pre_array based on the mask\n        G_H̅_pre_array = (\n            G_H_pre_array.at[:, :3, 3].set(\n                jnp.where(\n                    L_H_pre_mask[:, None],\n                    scale_vector[None, :] * G_H_pre_array[:, :3, 3],\n                    G_H_pre_array[:, :3, 3],\n                )\n            )\n            if has_joints\n            else G_H_pre_array\n        )\n\n        # Get back to the link frame\n        L_H̅_G = L_H_G.at[:3, 3].set(scale_vector * L_H_G[:3, 3])\n        L_H̅_vis = L_H̅_G @ G_H̅_vis\n        L_H̅_pre_array = (\n            jax.vmap(lambda G_H̅_pre: L_H̅_G @ G_H̅_pre)(G_H̅_pre_array)\n            if has_joints\n            else G_H̅_pre_array\n        )\n\n        # ===========================\n        # Update the shape parameters\n        # ===========================\n\n        updated_geoms = hw_metadata.geometry * scaling_factors.dims\n\n        # =============================\n        # Scale the density of the link\n        # 
=============================\n\n        updated_density = hw_metadata.density * scaling_factors.density\n\n        # =============================\n        # Return updated HwLinkMetadata\n        # =============================\n\n        return hw_metadata.replace(\n            geometry=updated_geoms,\n            density=updated_density,\n            L_H_G=L_H̅_G,\n            L_H_vis=L_H̅_vis,\n            L_H_pre=L_H̅_pre_array,\n        )\n\n\n@jax_dataclasses.pytree_dataclass\nclass ScalingFactors(JaxsimDataclass):\n    \"\"\"\n    Class storing scaling factors for hardware parameters.\n\n    Attributes:\n        dims: Scaling factors for shape dimensions.\n        density: Scaling factor for density.\n    \"\"\"\n\n    dims: jtp.Vector\n    density: jtp.Float\n\n\n@dataclasses.dataclass(frozen=True)\nclass ConstraintType:\n    \"\"\"\n    Enumeration of all supported constraint types.\n    \"\"\"\n\n    Weld: ClassVar[int] = 0\n    # TODO: handle Connect constraint\n    # Connect: ClassVar[int] = 1\n\n\n@jax_dataclasses.pytree_dataclass\nclass ConstraintMap(JaxsimDataclass):\n    \"\"\"\n    Class storing the kinematic constraints of a model.\n    \"\"\"\n\n    frame_idxs_1: jtp.Int = dataclasses.field(\n        default_factory=lambda: jnp.array([], dtype=int)\n    )\n    frame_idxs_2: jtp.Int = dataclasses.field(\n        default_factory=lambda: jnp.array([], dtype=int)\n    )\n    constraint_types: jtp.Int = dataclasses.field(\n        default_factory=lambda: jnp.array([], dtype=int)\n    )\n    K_P: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array([], dtype=float)\n    )\n    K_D: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array([], dtype=float)\n    )\n    # Precomputed parent link indices for each constraint pair\n    parent_link_idxs_1: jtp.Int = dataclasses.field(\n        default_factory=lambda: jnp.array([], dtype=int)\n    )\n    parent_link_idxs_2: jtp.Int = dataclasses.field(\n        
default_factory=lambda: jnp.array([], dtype=int)\n    )\n\n    def add_constraint(\n        self,\n        model: jaxsim.api.model.JaxSimModel,\n        frame_idx_1: int,\n        frame_idx_2: int,\n        constraint_type: int,\n        K_P: float | None = None,\n        K_D: float | None = None,\n    ) -> ConstraintMap:\n        \"\"\"\n        Add a constraint to the constraint map.\n\n        Args:\n            model: The model for which the constraints are added.\n            frame_idx_1: The index of the first frame.\n            frame_idx_2: The index of the second frame.\n            constraint_type: The type of constraint.\n            K_P: The proportional gain for Baumgarte stabilization (default: 1000).\n            K_D: The derivative gain for Baumgarte stabilization (default: 2 * sqrt(K_P)).\n\n        Returns:\n            A new ConstraintMap instance with the added constraint.\n\n        Note:\n            Since this method returns a new instance of ConstraintMap with the new constraint,\n            it will trigger recompilations in JIT-compiled functions.\n        \"\"\"\n\n        # Set default values for Baumgarte coefficients if not provided\n        if K_P is None:\n            K_P = jnp.array([1000.0])\n        if K_D is None:\n            K_D = 2 * jnp.sqrt(K_P)\n\n        # Create new arrays with the input elements appended\n        new_frame_idxs_1 = jnp.append(self.frame_idxs_1, frame_idx_1)\n        new_frame_idxs_2 = jnp.append(self.frame_idxs_2, frame_idx_2)\n        new_constraint_types = jnp.append(self.constraint_types, constraint_type)\n        new_K_P = jnp.append(self.K_P, K_P)\n        new_K_D = jnp.append(self.K_D, K_D)\n\n        # Compute parent link indices (now always available since model is required)\n        parent_link_idx_1 = jaxsim.api.frame.idx_of_parent_link(\n            model, frame_index=frame_idx_1\n        )\n        parent_link_idx_2 = jaxsim.api.frame.idx_of_parent_link(\n            model, 
frame_index=frame_idx_2\n        )\n        new_parent_link_idxs_1 = jnp.append(self.parent_link_idxs_1, parent_link_idx_1)\n        new_parent_link_idxs_2 = jnp.append(self.parent_link_idxs_2, parent_link_idx_2)\n\n        # Return a new ConstraintMap object with updated attributes\n        return ConstraintMap(\n            frame_idxs_1=new_frame_idxs_1,\n            frame_idxs_2=new_frame_idxs_2,\n            constraint_types=new_constraint_types,\n            K_P=new_K_P,\n            K_D=new_K_D,\n            parent_link_idxs_1=new_parent_link_idxs_1,\n            parent_link_idxs_2=new_parent_link_idxs_2,\n        )\n"
  },
  {
    "path": "src/jaxsim/api/link.py",
    "content": "import functools\nfrom collections.abc import Sequence\n\nimport jax\nimport jax.numpy as jnp\nimport jax.scipy.linalg\nimport numpy as np\n\nimport jaxsim.api as js\nimport jaxsim.rbda\nimport jaxsim.typing as jtp\nfrom jaxsim import exceptions\nfrom jaxsim.math import Adjoint\n\nfrom .common import VelRepr\n\n# =======================\n# Index-related functions\n# =======================\n\n\n@functools.partial(jax.jit, static_argnames=\"link_name\")\ndef name_to_idx(model: js.model.JaxSimModel, *, link_name: str) -> jtp.Int:\n    \"\"\"\n    Convert the name of a link to its index.\n\n    Args:\n        model: The model to consider.\n        link_name: The name of the link.\n\n    Returns:\n        The index of the link.\n    \"\"\"\n\n    if link_name not in model.link_names():\n        raise ValueError(f\"Link '{link_name}' not found in the model.\")\n\n    return (\n        jnp.array(model.kin_dyn_parameters.link_names.index(link_name))\n        .astype(int)\n        .squeeze()\n    )\n\n\ndef idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:\n    \"\"\"\n    Convert the index of a link to its name.\n\n    Args:\n        model: The model to consider.\n        link_index: The index of the link.\n\n    Returns:\n        The name of the link.\n    \"\"\"\n\n    exceptions.raise_value_error_if(\n        condition=link_index < 0,\n        msg=\"Invalid link index '{idx}'\",\n        idx=link_index,\n    )\n\n    return model.kin_dyn_parameters.link_names[link_index]\n\n\n@functools.partial(jax.jit, static_argnames=\"link_names\")\ndef names_to_idxs(\n    model: js.model.JaxSimModel, *, link_names: Sequence[str]\n) -> jax.Array:\n    \"\"\"\n    Convert a sequence of link names to their corresponding indices.\n\n    Args:\n        model: The model to consider.\n        link_names: The names of the links.\n\n    Returns:\n        The indices of the links.\n    \"\"\"\n\n    return jnp.array(\n        
[name_to_idx(model=model, link_name=name) for name in link_names],\n    ).astype(int)\n\n\ndef idxs_to_names(\n    model: js.model.JaxSimModel, *, link_indices: Sequence[jtp.IntLike] | jtp.VectorLike\n) -> tuple[str, ...]:\n    \"\"\"\n    Convert a sequence of link indices to their corresponding names.\n\n    Args:\n        model: The model to consider.\n        link_indices: The indices of the links.\n\n    Returns:\n        The names of the links.\n    \"\"\"\n\n    return tuple(np.array(model.kin_dyn_parameters.link_names)[list(link_indices)])\n\n\n# =========\n# Link APIs\n# =========\n\n\n@jax.jit\ndef mass(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float:\n    \"\"\"\n    Return the mass of the link.\n\n    Args:\n        model: The model to consider.\n        link_index: The index of the link.\n\n    Returns:\n        The mass of the link.\n    \"\"\"\n\n    exceptions.raise_value_error_if(\n        condition=jnp.array(\n            [link_index < 0, link_index >= model.number_of_links()]\n        ).any(),\n        msg=\"Invalid link index '{idx}'\",\n        idx=link_index,\n    )\n\n    return model.kin_dyn_parameters.link_parameters.mass[link_index].astype(float)\n\n\n@jax.jit\ndef spatial_inertia(\n    model: js.model.JaxSimModel, *, link_index: jtp.IntLike\n) -> jtp.Matrix:\n    r\"\"\"\n    Compute the 6D spatial inertia of the link.\n\n    Args:\n        model: The model to consider.\n        link_index: The index of the link.\n\n    Returns:\n        The :math:`6 \\times 6` matrix representing the spatial inertia of the link expressed in\n        the link frame (body-fixed representation).\n    \"\"\"\n\n    exceptions.raise_value_error_if(\n        condition=jnp.array(\n            [link_index < 0, link_index >= model.number_of_links()]\n        ).any(),\n        msg=\"Invalid link index '{idx}'\",\n        idx=link_index,\n    )\n\n    link_parameters = jax.tree.map(\n        lambda l: l[link_index], 
model.kin_dyn_parameters.link_parameters\n    )\n\n    return js.kin_dyn_parameters.LinkParameters.spatial_inertia(link_parameters)\n\n\n@jax.jit\ndef transform(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    link_index: jtp.IntLike,\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the SE(3) transform from the world frame to the link frame.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        link_index: The index of the link.\n\n    Returns:\n        The 4x4 matrix representing the transform.\n    \"\"\"\n\n    exceptions.raise_value_error_if(\n        condition=jnp.array(\n            [link_index < 0, link_index >= model.number_of_links()]\n        ).any(),\n        msg=\"Invalid link index '{idx}'\",\n        idx=link_index,\n    )\n\n    return data._link_transforms[link_index]\n\n\n@jax.jit\ndef com_position(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    link_index: jtp.IntLike,\n    in_link_frame: jtp.BoolLike = True,\n) -> jtp.Vector:\n    \"\"\"\n    Compute the position of the center of mass of the link.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        link_index: The index of the link.\n        in_link_frame:\n            Whether to return the position in the link frame or in the world frame.\n\n    Returns:\n        The 3D position of the center of mass of the link.\n    \"\"\"\n\n    from jaxsim.math.inertia import Inertia\n\n    _, L_p_CoM, _ = Inertia.to_params(\n        M=spatial_inertia(model=model, link_index=link_index)\n    )\n\n    def com_in_link_frame():\n        return L_p_CoM.squeeze()\n\n    def com_in_inertial_frame():\n        W_H_L = transform(link_index=link_index, model=model, data=data)\n        W_p̃_CoM = W_H_L @ jnp.hstack([L_p_CoM.squeeze(), 1])\n\n        return W_p̃_CoM[0:3].squeeze()\n\n    return jax.lax.select(\n        pred=in_link_frame,\n     
   on_true=com_in_link_frame(),\n        on_false=com_in_inertial_frame(),\n    )\n\n\n@functools.partial(jax.jit, static_argnames=[\"output_vel_repr\"])\ndef jacobian(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    link_index: jtp.IntLike,\n    output_vel_repr: VelRepr | None = None,\n) -> jtp.Matrix:\n    r\"\"\"\n    Compute the free-floating jacobian of the link.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        link_index: The index of the link.\n        output_vel_repr:\n            The output velocity representation of the free-floating jacobian.\n\n    Returns:\n        The :math:`6 \\times (6+n)` free-floating jacobian of the link.\n\n    Note:\n        The input representation of the free-floating jacobian is the active\n        velocity representation.\n    \"\"\"\n\n    exceptions.raise_value_error_if(\n        condition=jnp.array(\n            [link_index < 0, link_index >= model.number_of_links()]\n        ).any(),\n        msg=\"Invalid link index '{idx}'\",\n        idx=link_index,\n    )\n\n    output_vel_repr = (\n        output_vel_repr if output_vel_repr is not None else data.velocity_representation\n    )\n\n    # Compute the doubly-left free-floating full jacobian.\n    B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left(\n        model=model,\n        joint_positions=data.joint_positions,\n    )\n\n    # Compute the actual doubly-left free-floating jacobian of the link.\n    κb = model.kin_dyn_parameters.support_body_array_bool[link_index]\n    B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WX_B\n\n    # Adjust the input representation such that `J_WL_I @ I_ν`.\n    match data.velocity_representation:\n        case VelRepr.Inertial:\n            W_H_B = data._base_transform\n            B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True)\n            B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag(  # noqa: 
F841\n                B_X_W, jnp.eye(model.dofs())\n            )\n\n        case VelRepr.Body:\n            B_J_WL_I = B_J_WL_B\n\n        case VelRepr.Mixed:\n            W_R_B = jaxsim.math.Quaternion.to_dcm(data.base_orientation)\n            BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)\n            B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)\n            B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag(  # noqa: F841\n                B_X_BW, jnp.eye(model.dofs())\n            )\n\n        case _:\n            raise ValueError(data.velocity_representation)\n\n    B_H_L = B_H_Li[link_index]\n\n    # Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.\n    match output_vel_repr:\n        case VelRepr.Inertial:\n            W_H_B = data._base_transform\n            W_X_B = Adjoint.from_transform(transform=W_H_B)\n            O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I  # noqa: F841\n\n        case VelRepr.Body:\n            L_X_B = Adjoint.from_transform(transform=B_H_L, inverse=True)\n            L_J_WL_I = L_X_B @ B_J_WL_I\n            O_J_WL_I = L_J_WL_I\n\n        case VelRepr.Mixed:\n            W_H_B = data._base_transform\n            W_H_L = W_H_B @ B_H_L\n            LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))\n            LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)\n            LW_X_B = Adjoint.from_transform(transform=LW_H_B)\n            LW_J_WL_I = LW_X_B @ B_J_WL_I\n            O_J_WL_I = LW_J_WL_I\n\n        case _:\n            raise ValueError(output_vel_repr)\n\n    return O_J_WL_I\n\n\n@functools.partial(jax.jit, static_argnames=[\"output_vel_repr\"])\ndef velocity(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    link_index: jtp.IntLike,\n    output_vel_repr: VelRepr | None = None,\n) -> jtp.Vector:\n    \"\"\"\n    Compute the 6D velocity of the link.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n  
      link_index: The index of the link.\n        output_vel_repr:\n            The output velocity representation of the link velocity.\n\n    Returns:\n        The 6D velocity of the link in the specified velocity representation.\n    \"\"\"\n\n    exceptions.raise_value_error_if(\n        condition=jnp.array(\n            [link_index < 0, link_index >= model.number_of_links()]\n        ).any(),\n        msg=\"Invalid link index '{idx}'\",\n        idx=link_index,\n    )\n\n    output_vel_repr = (\n        output_vel_repr if output_vel_repr is not None else data.velocity_representation\n    )\n\n    # Get the link jacobian having I as input representation (taken from data)\n    # and O as output representation, specified by the user (or taken from data).\n    O_J_WL_I = jacobian(\n        model=model,\n        data=data,\n        link_index=link_index,\n        output_vel_repr=output_vel_repr,\n    )\n\n    # Get the generalized velocity in the input velocity representation.\n    I_ν = data.generalized_velocity\n\n    # Compute the link velocity in the output velocity representation.\n    return O_J_WL_I @ I_ν\n\n\n@functools.partial(jax.jit, static_argnames=[\"output_vel_repr\"])\ndef jacobian_derivative(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    link_index: jtp.IntLike,\n    output_vel_repr: VelRepr | None = None,\n) -> jtp.Matrix:\n    r\"\"\"\n    Compute the derivative of the free-floating jacobian of the link.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        link_index: The index of the link.\n        output_vel_repr:\n            The output velocity representation of the free-floating jacobian derivative.\n\n    Returns:\n        The derivative of the :math:`6 \\times (6+n)` free-floating jacobian of the link.\n\n    Note:\n        The input representation of the free-floating jacobian derivative is the active\n        velocity representation.\n    
\"\"\"\n\n    exceptions.raise_value_error_if(\n        condition=jnp.array(\n            [link_index < 0, link_index >= model.number_of_links()]\n        ).any(),\n        msg=\"Invalid link index '{idx}'\",\n        idx=link_index,\n    )\n\n    output_vel_repr = (\n        output_vel_repr if output_vel_repr is not None else data.velocity_representation\n    )\n\n    O_J̇_WL_I = js.model.generalized_free_floating_jacobian_derivative(\n        model=model, data=data, output_vel_repr=output_vel_repr\n    )[link_index]\n\n    return O_J̇_WL_I\n\n\n@jax.jit\ndef bias_acceleration(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    link_index: jtp.IntLike,\n) -> jtp.Vector:\n    \"\"\"\n    Compute the bias acceleration of the link.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        link_index: The index of the link.\n\n    Returns:\n        The 6D bias acceleration of the link.\n    \"\"\"\n\n    exceptions.raise_value_error_if(\n        condition=jnp.array(\n            [link_index < 0, link_index >= model.number_of_links()]\n        ).any(),\n        msg=\"Invalid link index '{idx}'\",\n        idx=link_index,\n    )\n\n    # Compute the bias acceleration of all links in the active representation.\n    O_v̇_WL = js.model.link_bias_accelerations(model=model, data=data)[link_index]\n    return O_v̇_WL\n"
  },
  {
    "path": "src/jaxsim/api/model.py",
    "content": "from __future__ import annotations\n\nimport copy\nimport dataclasses\nimport enum\nimport functools\nimport pathlib\nfrom collections.abc import Sequence\n\nimport jax\nimport jax.numpy as jnp\nimport jax_dataclasses\nimport numpy as np\nimport rod\nfrom jax_dataclasses import Static\nfrom rod.urdf.exporter import UrdfExporter\n\nimport jaxsim.api as js\nimport jaxsim.terrain\nimport jaxsim.typing as jtp\nfrom jaxsim import logging\nfrom jaxsim.api.kin_dyn_parameters import (\n    HwLinkMetadata,\n    KinDynParameters,\n    LinkParameters,\n    LinkParametrizableShape,\n    ScalingFactors,\n)\nfrom jaxsim.math import Adjoint, Cross, Skew\nfrom jaxsim.parsers.descriptions import ModelDescription\nfrom jaxsim.parsers.descriptions.joint import JointDescription\nfrom jaxsim.parsers.descriptions.link import LinkDescription\nfrom jaxsim.parsers.rod.utils import prepare_mesh_for_parametrization\nfrom jaxsim.utils import JaxsimDataclass, Mutability, wrappers\n\nfrom .common import VelRepr\n\n\nclass IntegratorType(enum.IntEnum):\n    \"\"\"The integrators available for the simulation.\"\"\"\n\n    SemiImplicitEuler = enum.auto()\n    RungeKutta4 = enum.auto()\n    RungeKutta4Fast = enum.auto()\n\n\n@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)\nclass JaxSimModel(JaxsimDataclass):\n    \"\"\"\n    The JaxSim model defining the kinematics and dynamics of a robot.\n    \"\"\"\n\n    model_name: Static[str]\n\n    time_step: float = dataclasses.field(\n        default=0.001,\n    )\n\n    terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(\n        default_factory=jaxsim.terrain.FlatTerrain.build, repr=False\n    )\n\n    gravity: Static[float] = -jaxsim.math.STANDARD_GRAVITY\n\n    contact_model: Static[jaxsim.rbda.contacts.ContactModel | None] = dataclasses.field(\n        default=None, repr=False\n    )\n\n    contact_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(\n        default=None, repr=False\n    )\n\n    
actuation_params: Static[jaxsim.rbda.actuation.ActuationParams] = dataclasses.field(\n        default=None, repr=False\n    )\n\n    kin_dyn_parameters: js.kin_dyn_parameters.KinDynParameters | None = (\n        dataclasses.field(default=None, repr=False)\n    )\n\n    integrator: Static[IntegratorType] = dataclasses.field(\n        default=IntegratorType.SemiImplicitEuler, repr=False\n    )\n\n    built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(\n        default=None, repr=False\n    )\n\n    _description: Static[wrappers.HashlessObject[ModelDescription | None]] = (\n        dataclasses.field(default=None, repr=False)\n    )\n\n    @property\n    def description(self) -> ModelDescription:\n        \"\"\"\n        Return the model description.\n        \"\"\"\n        return self._description.get()\n\n    def __eq__(self, other: JaxSimModel) -> bool:\n        if not isinstance(other, JaxSimModel):\n            return False\n\n        if self.model_name != other.model_name:\n            return False\n\n        if self.time_step != other.time_step:\n            return False\n\n        if self.kin_dyn_parameters != other.kin_dyn_parameters:\n            return False\n\n        return True\n\n    def __hash__(self) -> int:\n        return hash(\n            (\n                hash(self.model_name),\n                hash(self.time_step),\n                hash(self.kin_dyn_parameters),\n                hash(self.contact_model),\n            )\n        )\n\n    # ========================\n    # Initialization and state\n    # ========================\n\n    @classmethod\n    def build_from_model_description(\n        cls,\n        model_description: str | pathlib.Path | rod.Model,\n        *,\n        model_name: str | None = None,\n        time_step: jtp.FloatLike | None = None,\n        terrain: jaxsim.terrain.Terrain | None = None,\n        contact_model: jaxsim.rbda.contacts.ContactModel | None = None,\n        contact_params: 
jaxsim.rbda.contacts.ContactsParams | None = None,\n        actuation_params: jaxsim.rbda.actuation.ActuationParams | None = None,\n        integrator: IntegratorType | None = None,\n        is_urdf: bool | None = None,\n        considered_joints: Sequence[str] | None = None,\n        gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,\n        constraints: jaxsim.rbda.kinematic_constraints.ConstraintMap | None = None,\n        parametrized_links: tuple[str, ...] | None = None,\n    ) -> JaxSimModel:\n        \"\"\"\n        Build a Model object from a model description.\n\n        Args:\n            model_description:\n                A path to an SDF/URDF file, a string containing\n                its content, or a pre-parsed/pre-built rod model.\n            model_name:\n                The name of the model. If not specified, it is read from the description.\n            time_step:\n                The default time step to consider for the simulation. It can be\n                manually overridden in the function that steps the simulation.\n            terrain: The terrain to consider (the default is a flat infinite plane).\n            contact_model:\n                The contact model to consider.\n                If not specified, a soft contacts model is used.\n            contact_params: The parameters of the contact model.\n            actuation_params: The parameters of the actuation model.\n            integrator: The integrator to use for the simulation.\n            is_urdf:\n                The optional flag to force the model description to be parsed as a URDF.\n                This is usually automatically inferred.\n            considered_joints:\n                The list of joints to consider. If None, all joints are considered.\n            gravity: The gravity constant. Normally passed as a positive value.\n            constraints:\n                An object of type ConstraintMap containing the kinematic constraints to consider. 
If None, no constraints are considered.\n                Note that constraints can be used only with RelaxedRigidContacts.\n            parametrized_links:\n                The optional list of links to be parametrized. If None, all links are parametrized.\n\n        Returns:\n            The built Model object.\n        \"\"\"\n\n        import jaxsim.parsers.rod\n\n        # Parse the input resource (either a path to file or a string with the URDF/SDF)\n        # and build the -intermediate- model description.\n        intermediate_description = jaxsim.parsers.rod.build_model_description(\n            model_description=model_description, is_urdf=is_urdf\n        )\n\n        # Lump links together if not all joints are considered.\n        # Note: this procedure assigns a zero position to all joints not considered.\n        if considered_joints is not None:\n            intermediate_description = intermediate_description.reduce(\n                considered_joints=considered_joints\n            )\n\n        # Build the model.\n        model = cls.build(\n            model_description=intermediate_description,\n            model_name=model_name,\n            time_step=time_step,\n            terrain=terrain,\n            contact_model=contact_model,\n            actuation_params=actuation_params,\n            contact_params=contact_params,\n            integrator=integrator,\n            gravity=-gravity,\n            constraints=constraints,\n            parametrized_links=parametrized_links,\n        )\n\n        # Store the origin of the model, in case downstream logic needs it.\n        with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):\n            model.built_from = model_description\n\n        # Compute the hw parametrization metadata of the model\n        # TODO: move the building of the metadata to KinDynParameters.build()\n        #       and use the model_description instead of model.built_from.\n        with 
model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):\n            model.kin_dyn_parameters.hw_link_metadata = model.compute_hw_link_metadata(\n                parametrized_links=parametrized_links\n            )\n\n        return model\n\n    @classmethod\n    def build(\n        cls,\n        model_description: ModelDescription,\n        *,\n        model_name: str | None = None,\n        time_step: jtp.FloatLike | None = None,\n        terrain: jaxsim.terrain.Terrain | None = None,\n        contact_model: jaxsim.rbda.contacts.ContactModel | None = None,\n        contact_params: jaxsim.rbda.contacts.ContactsParams | None = None,\n        actuation_params: jaxsim.rbda.actuation.ActuationParams | None = None,\n        integrator: IntegratorType | None = None,\n        gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,\n        constraints: jaxsim.rbda.kinematic_constraints.ConstraintMap | None = None,\n        parametrized_links: tuple[str, ...] | None = None,\n    ) -> JaxSimModel:\n        \"\"\"\n        Build a Model object from an intermediate model description.\n\n        Args:\n            model_description:\n                The intermediate model description defining the kinematics and dynamics\n                of the model.\n            model_name:\n                The name of the model. If not specified, it is read from the description.\n            time_step:\n                The default time step to consider for the simulation. 
It can be\n                manually overridden in the function that steps the simulation.\n            terrain: The terrain to consider (the default is a flat infinite plane).\n            contact_model:\n                The contact model to consider.\n                If not specified, a soft contact model is used.\n            contact_params: The parameters of the contact model.\n            actuation_params: The parameters of the actuation model.\n            integrator: The integrator to use for the simulation.\n            gravity:\n                The gravity constant. It is stored in the model exactly as given,\n                without any sign inversion.\n            constraints:\n                An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered.\n            parametrized_links:\n                The optional list of links to be parametrized. If None, all links are parametrized.\n\n        Returns:\n            The built Model object.\n        \"\"\"\n\n        # Set the model name (if not provided, use the one from the model description).\n        model_name = model_name if model_name is not None else model_description.name\n\n        # Consider the default terrain (a flat infinite plane) if not specified.\n        terrain = (\n            terrain\n            if terrain is not None\n            else JaxSimModel.__dataclass_fields__[\"terrain\"].default_factory()\n        )\n\n        # Consider the default time step if not specified.\n        time_step = (\n            time_step\n            if time_step is not None\n            else JaxSimModel.__dataclass_fields__[\"time_step\"].default\n        )\n\n        # Create the default contact model.\n        # It will be populated with an initial estimation of good parameters.\n        # While these might not be the best, they are a good starting point.\n        contact_model = (\n            contact_model\n            if contact_model is not None\n            
else jaxsim.rbda.contacts.SoftContacts.build()\n        )\n\n        if contact_params is None:\n            contact_params = contact_model._parameters_class()\n\n        if actuation_params is None:\n            actuation_params = jaxsim.rbda.actuation.ActuationParams()\n\n        # Consider the default integrator if not specified.\n        integrator = (\n            integrator\n            if integrator is not None\n            else JaxSimModel.__dataclass_fields__[\"integrator\"].default\n        )\n\n        # Build the model.\n        model = cls(\n            model_name=model_name,\n            kin_dyn_parameters=js.kin_dyn_parameters.KinDynParameters.build(\n                model_description=model_description, constraints=constraints\n            ),\n            time_step=time_step,\n            terrain=terrain,\n            contact_model=contact_model,\n            contact_params=contact_params,\n            actuation_params=actuation_params,\n            integrator=integrator,\n            gravity=gravity,\n            # The following is wrapped as hashless since it's a static argument, and we\n            # don't want to trigger recompilation if it changes. All relevant parameters\n            # needed to compute kinematics and dynamics quantities are stored in the\n            # kin_dyn_parameters attribute.\n            _description=wrappers.HashlessObject(obj=model_description),\n        )\n\n        return model\n\n    def compute_hw_link_metadata(\n        self, parametrized_links: tuple[str, ...] | None = None\n    ) -> HwLinkMetadata:\n        \"\"\"\n        Compute the parametric metadata of the links in the model.\n\n        Args:\n            parametrized_links:\n                An optional tuple of link names to be parametrized. 
If None,\n                all links will be parametrized.\n\n        Returns:\n            An instance of HwLinkMetadata containing the metadata of all links.\n        \"\"\"\n        model_description = self.description\n\n        # Get ordered links and joints from the model description\n        ordered_links: list[LinkDescription] = sorted(\n            list(model_description.links_dict.values()),\n            key=lambda l: l.index,\n        )\n        ordered_joints: list[JointDescription] = sorted(\n            list(model_description.joints_dict.values()),\n            key=lambda j: j.index,\n        )\n\n        # Ensure the model was built from a valid source\n        rod_model = None\n        match self.built_from:\n            case str() | pathlib.Path():\n                rod_model = rod.Sdf.load(sdf=self.built_from).models()[0]\n                assert rod_model.name == self.name()\n            case rod.Model():\n                rod_model = self.built_from\n            case _:\n                logging.debug(\n                    f\"Invalid type for model.built_from ({type(self.built_from)}).\"\n                    \"Skipping for hardware parametrization.\"\n                )\n                return HwLinkMetadata.empty()\n\n        # Use URDF frame convention for consistent pose representation\n        rod_model.switch_frame_convention(\n            frame_convention=rod.FrameConvention.Urdf, explicit_frames=True\n        )\n\n        rod_links_dict = {link.name: link for link in rod_model.links()}\n\n        # Initialize lists to collect metadata for all links\n        shapes = []\n        geoms = []\n        densities = []\n        L_H_Gs = []\n        L_H_vises = []\n        L_H_pre_masks = []\n        L_H_pre = []\n        mesh_moments_list = []\n        mesh_vertices = []\n        mesh_faces = []\n        mesh_offsets = []\n        mesh_uris = []\n\n        # Process each link, only parametrizing those in parametrized_links if provided\n        for 
link_description in ordered_links:\n            link_name = link_description.name\n\n            if parametrized_links is not None and link_name not in parametrized_links:\n                # Mark as unsupported for non-parametrized links\n                shapes.append(LinkParametrizableShape.Unsupported)\n                geoms.append([0, 0, 0])\n                densities.append(0.0)\n                L_H_Gs.append(jnp.eye(4))\n                L_H_vises.append(jnp.eye(4))\n                L_H_pre_masks.append([0] * self.number_of_joints())\n                L_H_pre.append([jnp.eye(4)] * self.number_of_joints())\n                mesh_vertices.append(None)\n                mesh_faces.append(None)\n                mesh_offsets.append(None)\n                mesh_uris.append(None)\n                mesh_moments_list.append(np.zeros(13))\n                continue\n\n            rod_link = rod_links_dict.get(link_name)\n            link_index = int(js.link.name_to_idx(model=self, link_name=link_name))\n\n            # Get child joints for the link\n            child_joints_indices = [\n                js.joint.name_to_idx(model=self, joint_name=j.name)\n                for j in ordered_joints\n                if j.parent.name == link_name\n            ]\n\n            # Skip unsupported links\n            if not jnp.allclose(\n                self.kin_dyn_parameters.joint_model.suc_H_i[link_index],\n                jnp.eye(4),\n                **(dict(atol=1e-6) if not jax.config.jax_enable_x64 else {}),\n            ):\n                logging.debug(\n                    f\"Skipping link '{link_name}' for hardware parametrization due to unsupported suc_H_link.\"\n                )\n                rod_link = None\n\n            # Compute density and dimensions\n            mass = float(self.kin_dyn_parameters.link_parameters.mass[link_index])\n\n            # Find the first supported visual\n            supported_visual = (\n                next(\n                    (\n     
                   v\n                        for v in rod_link.visuals()\n                        if isinstance(\n                            v.geometry.geometry(),\n                            (rod.Box, rod.Sphere, rod.Cylinder, rod.Mesh),\n                        )\n                    ),\n                    None,\n                )\n                if rod_link\n                else None\n            )\n\n            geometry = (\n                supported_visual.geometry.geometry() if supported_visual else None\n            )\n\n            if isinstance(geometry, rod.Box):\n                lx, ly, lz = geometry.size\n                density = mass / (lx * ly * lz)\n                geom = [lx, ly, lz]\n                shape = LinkParametrizableShape.Box\n                mesh_vertices.append(None)\n                mesh_faces.append(None)\n                mesh_offsets.append(None)\n                mesh_uris.append(None)\n                mesh_moments_list.append(np.zeros(13))\n            elif isinstance(geometry, rod.Sphere):\n                r = geometry.radius\n                density = mass / (4 / 3 * jnp.pi * r**3)\n                geom = [r, 0, 0]\n                shape = LinkParametrizableShape.Sphere\n                mesh_vertices.append(None)\n                mesh_faces.append(None)\n                mesh_offsets.append(None)\n                mesh_uris.append(None)\n                mesh_moments_list.append(np.zeros(13))\n            elif isinstance(geometry, rod.Cylinder):\n                r, l = geometry.radius, geometry.length\n                density = mass / (jnp.pi * r**2 * l)\n                geom = [r, l, 0]\n                shape = LinkParametrizableShape.Cylinder\n                mesh_vertices.append(None)\n                mesh_faces.append(None)\n                mesh_offsets.append(None)\n                mesh_uris.append(None)\n                mesh_moments_list.append(np.zeros(13))\n            elif isinstance(geometry, rod.Mesh):\n              
  # Load and prepare mesh for parametric scaling\n                try:\n\n                    mesh_data = prepare_mesh_for_parametrization(\n                        mesh_uri=geometry.uri,\n                        scale=geometry.scale,\n                    )\n\n                    density = (\n                        mass / mesh_data[\"volume\"] if mesh_data[\"volume\"] > 0 else 0.0\n                    )\n\n                    # For meshes, store cumulative scale factors (initially 1.0) in geometry\n                    # instead of bounding box extents. This allows proper multiplicative scaling.\n                    geom = [1.0, 1.0, 1.0]\n                    shape = LinkParametrizableShape.Mesh\n\n                    # Store mesh data\n                    mesh_vertices.append(mesh_data[\"vertices\"])\n                    mesh_faces.append(mesh_data[\"faces\"])\n                    mesh_offsets.append(mesh_data[\"offset\"])\n                    mesh_uris.append(mesh_data[\"uri\"])\n\n                    # Precompute volumetric moments for JIT-friendly inertia computation\n                    mesh_moments_list.append(\n                        HwLinkMetadata.precompute_mesh_moments(\n                            mesh_data[\"vertices\"], mesh_data[\"faces\"]\n                        )\n                    )\n\n                    logging.info(\n                        f\"Loaded mesh for link '{link_name}': \"\n                        f\"{len(mesh_data['vertices'])} vertices, \"\n                        f\"{len(mesh_data['faces'])} faces, \"\n                    )\n                except Exception as e:\n                    logging.warning(\n                        f\"Failed to load mesh for link '{link_name}': {e}. 
\"\n                        f\"Marking as unsupported.\"\n                    )\n                    density = 0.0\n                    geom = [0, 0, 0]\n                    shape = LinkParametrizableShape.Unsupported\n                    mesh_vertices.append(None)\n                    mesh_faces.append(None)\n                    mesh_offsets.append(None)\n                    mesh_uris.append(None)\n                    mesh_moments_list.append(np.zeros(13))\n            else:\n                logging.debug(\n                    f\"Skipping link '{link_name}' for hardware parametrization due to unsupported geometry.\"\n                )\n                density = 0.0\n                geom = [0, 0, 0]\n                shape = LinkParametrizableShape.Unsupported\n                mesh_vertices.append(None)\n                mesh_faces.append(None)\n                mesh_offsets.append(None)\n                mesh_uris.append(None)\n                mesh_moments_list.append(np.zeros(13))\n\n            inertial_pose = (\n                rod_link.inertial.pose.transform() if rod_link else jnp.eye(4)\n            )\n            visual_pose = (\n                supported_visual.pose.transform() if supported_visual else jnp.eye(4)\n            )\n            l_h_pre_mask = [\n                int(joint_index in child_joints_indices)\n                for joint_index in range(self.number_of_joints())\n            ]\n            l_h_pre = [\n                (\n                    self.kin_dyn_parameters.joint_model.λ_H_pre[joint_index + 1]\n                    if joint_index in child_joints_indices\n                    else jnp.eye(4)\n                )\n                for joint_index in range(self.number_of_joints())\n            ]\n\n            shapes.append(shape)\n            geoms.append(geom)\n            densities.append(density)\n            L_H_Gs.append(inertial_pose)\n            L_H_vises.append(visual_pose)\n            L_H_pre_masks.append(l_h_pre_mask)\n            
L_H_pre.append(l_h_pre)\n\n        if np.all(np.array(shapes) == LinkParametrizableShape.Unsupported):\n            logging.debug(\n                \"All links were skipped for hardware parametrization. Returning empty metadata.\"\n            )\n            return HwLinkMetadata.empty()\n\n        # Stack collected data into JAX arrays\n        # Handle L_H_pre specially: ensure shape (n_links, n_joints, 4, 4) even when n_joints=0\n        L_H_pre_array = jnp.array(L_H_pre, dtype=float)\n        if self.number_of_joints() == 0:\n            # Reshape from (n_links, 0) to (n_links, 0, 4, 4)\n            n_links = len(L_H_pre)\n            L_H_pre_array = L_H_pre_array.reshape(n_links, 0, 4, 4)\n\n        return HwLinkMetadata(\n            link_shape=jnp.array(shapes, dtype=int),\n            geometry=jnp.array(geoms, dtype=float),\n            density=jnp.array(densities, dtype=float),\n            L_H_G=jnp.array(L_H_Gs, dtype=float),\n            L_H_vis=jnp.array(L_H_vises, dtype=float),\n            L_H_pre_mask=jnp.array(L_H_pre_masks, dtype=bool),\n            L_H_pre=L_H_pre_array,\n            mesh_moments=jnp.array(np.stack(mesh_moments_list), dtype=float),\n            mesh_vertices=(\n                tuple(\n                    wrappers.HashedNumpyArray(array=v) if v is not None else None\n                    for v in mesh_vertices\n                )\n                if any(v is not None for v in mesh_vertices)\n                else None\n            ),\n            mesh_faces=(\n                tuple(\n                    wrappers.HashedNumpyArray(array=f) if f is not None else None\n                    for f in mesh_faces\n                )\n                if any(f is not None for f in mesh_faces)\n                else None\n            ),\n            mesh_offset=(\n                tuple(\n                    wrappers.HashedNumpyArray(array=o) if o is not None else None\n                    for o in mesh_offsets\n                )\n                
if any(o is not None for o in mesh_offsets)\n                else None\n            ),\n            mesh_uri=(\n                tuple(mesh_uris) if any(u is not None for u in mesh_uris) else None\n            ),\n        )\n\n    def export_updated_model(self) -> str:\n        \"\"\"\n        Export the JaxSim model to URDF with the current hardware parameters.\n\n        Returns:\n            The URDF string of the updated model.\n\n        Note:\n            This method is not meant to be used in JIT-compiled functions.\n        \"\"\"\n\n        if isinstance(jnp.zeros(0), jax.core.Tracer):\n            raise RuntimeError(\"This method cannot be used in JIT-compiled functions\")\n\n        # Ensure `built_from` is a ROD model and create `rod_model_output`\n        if isinstance(self.built_from, rod.Model):\n            rod_model_output = copy.deepcopy(self.built_from)\n        elif isinstance(self.built_from, (str, pathlib.Path)):\n            rod_model_output = rod.Sdf.load(sdf=self.built_from).models()[0]\n        else:\n            raise ValueError(\n                \"The JaxSim model must be built from a valid ROD model source\"\n            )\n\n        # Switch to URDF frame convention for easier mapping\n        rod_model_output.switch_frame_convention(\n            frame_convention=rod.FrameConvention.Urdf,\n            explicit_frames=True,\n            attach_frames_to_links=True,\n        )\n\n        # Get links and joints from the ROD model\n        links_dict = {link.name: link for link in rod_model_output.links()}\n        joints_dict = {joint.name: joint for joint in rod_model_output.joints()}\n\n        # Iterate over the hardware metadata to update the ROD model\n        hw_metadata = self.kin_dyn_parameters.hw_link_metadata\n        reduced_link_names = set(self.link_names())\n        reduced_joint_names = set(self.joint_names())\n        unit_scale = np.ones(3, dtype=float)\n        link_scale_factors: dict[str, np.ndarray] = {}\n\n        
def collect_link_elements(link) -> list:\n            elements_to_update_raw = (link.visual, link.collision)\n            elements_to_update = []\n            for entry in elements_to_update_raw:\n                if entry is None:\n                    continue\n                if isinstance(entry, (list, tuple)):\n                    elements_to_update.extend(e for e in entry if e is not None)\n                else:\n                    elements_to_update.append(entry)\n            return elements_to_update\n\n        def scale_pose_translation(element, scale_vector):\n            if getattr(element, \"pose\", None) is None:\n                return\n            transform = np.array(element.pose.transform(), dtype=float)\n            transform[0:3, 3] = scale_vector * transform[0:3, 3]\n            element.pose = rod.Pose.from_transform(\n                transform=transform,\n                relative_to=element.pose.relative_to,\n            )\n\n        def scale_link_elements(\n            elements_to_update: list,\n            scale_vector: np.ndarray,\n            *,\n            mesh_pose: rod.Pose | None = None,\n            mesh_shape_link: bool = False,\n        ) -> None:\n            for element in elements_to_update:\n                if (\n                    element is None\n                    or not hasattr(element, \"geometry\")\n                    or element.geometry is None\n                ):\n                    continue\n\n                geometry = element.geometry\n                if getattr(geometry, \"box\", None) is not None:\n                    current_size = np.array(geometry.box.size, dtype=float)\n                    geometry.box.size = tuple(\n                        float(v) for v in (current_size * scale_vector).tolist()\n                    )\n                    scale_pose_translation(element, scale_vector)\n                elif getattr(geometry, \"sphere\", None) is not None:\n                    geometry.sphere.radius = float(\n 
                       float(geometry.sphere.radius) * float(scale_vector[0])\n                    )\n                    scale_pose_translation(element, scale_vector)\n                elif getattr(geometry, \"cylinder\", None) is not None:\n                    geometry.cylinder.radius = float(\n                        float(geometry.cylinder.radius) * float(scale_vector[0])\n                    )\n                    geometry.cylinder.length = float(\n                        float(geometry.cylinder.length) * float(scale_vector[2])\n                    )\n                    scale_pose_translation(element, scale_vector)\n                elif getattr(geometry, \"mesh\", None) is not None:\n                    base_scale = (\n                        np.array(geometry.mesh.scale, dtype=float)\n                        if geometry.mesh.scale is not None\n                        else unit_scale\n                    )\n                    geometry.mesh.scale = tuple(\n                        float(v) for v in (base_scale * scale_vector).tolist()\n                    )\n\n                    # Mesh-parametrized reduced links use metadata to preserve\n                    # the main visual placement in the exported URDF.\n                    if mesh_shape_link and mesh_pose is not None:\n                        element.pose = mesh_pose\n                    else:\n                        scale_pose_translation(element, scale_vector)\n\n        for link_index, link_name in enumerate(self.link_names()):\n            if link_name not in links_dict:\n                continue\n\n            # Skip links with unsupported shapes\n            shape = hw_metadata.link_shape[link_index]\n            if shape == LinkParametrizableShape.Unsupported:\n                logging.debug(f\"Skipping link '{link_name}' with unsupported shape\")\n                continue\n\n            # Update mass and inertia\n            mass = float(self.kin_dyn_parameters.link_parameters.mass[link_index])\n   
         center_of_mass = np.array(\n                self.kin_dyn_parameters.link_parameters.center_of_mass[link_index]\n            )\n            inertia_tensor = LinkParameters.unflatten_inertia_tensor(\n                self.kin_dyn_parameters.link_parameters.inertia_elements[link_index]\n            )\n\n            links_dict[link_name].inertial.mass = mass\n            L_H_COM = np.eye(4)\n            L_H_COM[0:3, 3] = center_of_mass\n            links_dict[link_name].inertial.pose = rod.Pose.from_transform(\n                transform=L_H_COM,\n                relative_to=links_dict[link_name].inertial.pose.relative_to,\n            )\n            links_dict[link_name].inertial.inertia = rod.Inertia.from_inertia_tensor(\n                inertia_tensor=inertia_tensor, validate=True\n            )\n\n            dims = np.array(hw_metadata.geometry[link_index], dtype=float)\n            elements_to_update = collect_link_elements(links_dict[link_name])\n\n            def find_reference_geometry(attr: str, elements: list = elements_to_update):\n                for element in elements:\n                    if (\n                        element is None\n                        or not hasattr(element, \"geometry\")\n                        or element.geometry is None\n                    ):\n                        continue\n                    geometry = getattr(element.geometry, attr, None)\n                    if geometry is not None:\n                        return geometry\n                return None\n\n            if shape == LinkParametrizableShape.Mesh:\n                scale_vector = dims\n            elif shape == LinkParametrizableShape.Box:\n                ref_box = find_reference_geometry(\"box\")\n                if ref_box is None:\n                    scale_vector = unit_scale\n                else:\n                    base_size = np.array(ref_box.size, dtype=float)\n                    scale_vector = np.divide(\n                        dims,\n    
                    base_size,\n                        out=np.ones(3, dtype=float),\n                        where=np.abs(base_size) > 1e-12,\n                    )\n            elif shape == LinkParametrizableShape.Sphere:\n                ref_sphere = find_reference_geometry(\"sphere\")\n                base_radius = (\n                    float(ref_sphere.radius) if ref_sphere is not None else 1.0\n                )\n                s = float(dims[0]) / base_radius if abs(base_radius) > 1e-12 else 1.0\n                scale_vector = np.array([s, s, s], dtype=float)\n            elif shape == LinkParametrizableShape.Cylinder:\n                ref_cylinder = find_reference_geometry(\"cylinder\")\n                base_radius = (\n                    float(ref_cylinder.radius) if ref_cylinder is not None else 1.0\n                )\n                base_length = (\n                    float(ref_cylinder.length) if ref_cylinder is not None else 1.0\n                )\n                s_radius = (\n                    float(dims[0]) / base_radius if abs(base_radius) > 1e-12 else 1.0\n                )\n                s_length = (\n                    float(dims[1]) / base_length if abs(base_length) > 1e-12 else 1.0\n                )\n                scale_vector = np.array([s_radius, s_radius, s_length], dtype=float)\n            else:\n                scale_vector = unit_scale\n\n            link_scale_factors[link_name] = np.array(scale_vector, dtype=float)\n\n            element_pose = rod.Pose.from_transform(\n                transform=np.array(hw_metadata.L_H_vis[link_index]),\n                relative_to=link_name,\n            )\n            scale_link_elements(\n                elements_to_update=elements_to_update,\n                scale_vector=scale_vector,\n                mesh_pose=element_pose,\n                mesh_shape_link=(shape == LinkParametrizableShape.Mesh),\n            )\n\n            # Update joint poses\n            for joint_index in 
range(self.number_of_joints()):\n                if hw_metadata.L_H_pre_mask[link_index, joint_index]:\n                    joint_name = js.joint.idx_to_name(\n                        model=self, joint_index=joint_index\n                    )\n                    if joint_name in joints_dict:\n                        joints_dict[joint_name].pose = rod.Pose.from_transform(\n                            transform=np.array(\n                                hw_metadata.L_H_pre[link_index, joint_index]\n                            ),\n                            relative_to=link_name,\n                        )\n\n        # Propagate link scaling to descendants connected through fixed joints.\n        # These links are typically reduced away in the JaxSim model (e.g. feet\n        # attached to ankles) but still exist in the exported URDF tree.\n        updated = True\n        while updated:\n            updated = False\n            for joint in joints_dict.values():\n                if joint.type != \"fixed\":\n                    continue\n                parent_scale = link_scale_factors.get(joint.parent, None)\n                if parent_scale is None or joint.child in link_scale_factors:\n                    continue\n                link_scale_factors[joint.child] = np.array(parent_scale, dtype=float)\n                updated = True\n\n        # Scale fixed-joint offsets that are not part of the reduced joint set.\n        for joint_name, joint in joints_dict.items():\n            if joint.type != \"fixed\" or joint_name in reduced_joint_names:\n                continue\n            parent_scale = link_scale_factors.get(joint.parent, unit_scale)\n            if np.allclose(parent_scale, unit_scale):\n                continue\n            if joint.pose is None:\n                continue\n            transform = np.array(joint.pose.transform(), dtype=float)\n            transform[0:3, 3] = parent_scale * transform[0:3, 3]\n            joint.pose = 
rod.Pose.from_transform(\n                transform=transform,\n                relative_to=joint.pose.relative_to,\n            )\n\n        # Apply inherited scaling to non-reduced links (typically descendants\n        # connected via fixed joints).\n        for link_name, scale_vector in link_scale_factors.items():\n            if link_name in reduced_link_names:\n                continue\n            if np.allclose(scale_vector, unit_scale):\n                continue\n            if link_name not in links_dict:\n                continue\n            scale_link_elements(\n                elements_to_update=collect_link_elements(links_dict[link_name]),\n                scale_vector=scale_vector,\n            )\n\n        # Restore continuous joint types for joints with infinite limits\n        # to ensure valid URDF export (continuous joints should not have limits).\n        # Continuous joints are internally represented as revolute with infinite\n        # limits, but must be exported as type=\"continuous\" for valid URDF.\n        for joint in joints_dict.values():\n            # Skip if not a revolute joint with axis and limits\n            if not (\n                joint.type == \"revolute\"\n                and joint.axis is not None\n                and joint.axis.limit is not None\n            ):\n                continue\n\n            lower, upper = joint.axis.limit.lower, joint.axis.limit.upper\n\n            # Check if both limits are infinite (indicating original continuous joint)\n            if not (\n                lower is not None\n                and upper is not None\n                and np.isinf(lower)\n                and lower < 0\n                and np.isinf(upper)\n                and upper > 0\n            ):\n                continue\n\n            # Restore as continuous joint\n            joint.type = \"continuous\"\n\n            # Create a new Limit object with only effort and velocity\n            # (no position limits for 
continuous joints)\n            joint.axis.limit = rod.Limit(\n                effort=joint.axis.limit.effort,\n                velocity=joint.axis.limit.velocity,\n                lower=None,\n                upper=None,\n            )\n\n        # Export the URDF string\n        urdf_string = UrdfExporter(pretty=True).to_urdf_string(sdf=rod_model_output)\n\n        return urdf_string\n\n    # ==========\n    # Properties\n    # ==========\n\n    def name(self) -> str:\n        \"\"\"\n        Return the name of the model.\n\n        Returns:\n            The name of the model.\n        \"\"\"\n\n        return self.model_name\n\n    def number_of_links(self) -> int:\n        \"\"\"\n        Return the number of links in the model.\n\n        Returns:\n            The number of links in the model.\n\n        Note:\n            The base link is included in the count and its index is always 0.\n        \"\"\"\n\n        return self.kin_dyn_parameters.number_of_links()\n\n    def number_of_joints(self) -> int:\n        \"\"\"\n        Return the number of joints in the model.\n\n        Returns:\n            The number of joints in the model.\n        \"\"\"\n\n        return self.kin_dyn_parameters.number_of_joints()\n\n    def number_of_frames(self) -> int:\n        \"\"\"\n        Return the number of frames in the model.\n\n        Returns:\n            The number of frames in the model.\n\n        \"\"\"\n\n        return self.kin_dyn_parameters.number_of_frames()\n\n    # =================\n    # Base link methods\n    # =================\n\n    def floating_base(self) -> bool:\n        \"\"\"\n        Return whether the model has a floating base.\n\n        Returns:\n            True if the model is floating-base, False otherwise.\n        \"\"\"\n\n        return self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6\n\n    def base_link(self) -> str:\n        \"\"\"\n        Return the name of the base link.\n\n        Returns:\n            The name of the 
base link.\n\n        Note:\n            By default, the base link is the root of the kinematic tree.\n        \"\"\"\n\n        return self.link_names()[0]\n\n    # =====================\n    # Joint-related methods\n    # =====================\n\n    def dofs(self) -> int:\n        \"\"\"\n        Return the number of degrees of freedom of the model.\n\n        Returns:\n            The number of degrees of freedom of the model.\n\n        Note:\n            We do not yet support multi-DoF joints, therefore this is always equal to\n            the number of joints. In the future, this could be different.\n        \"\"\"\n\n        return sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:])\n\n    def joint_names(self) -> tuple[str, ...]:\n        \"\"\"\n        Return the names of the joints in the model.\n\n        Returns:\n            The names of the joints in the model.\n        \"\"\"\n\n        return self.kin_dyn_parameters.joint_model.joint_names[1:]\n\n    # ====================\n    # Link-related methods\n    # ====================\n\n    def link_names(self) -> tuple[str, ...]:\n        \"\"\"\n        Return the names of the links in the model.\n\n        Returns:\n            The names of the links in the model.\n        \"\"\"\n\n        return self.kin_dyn_parameters.link_names\n\n    # =====================\n    # Frame-related methods\n    # =====================\n\n    def frame_names(self) -> tuple[str, ...]:\n        \"\"\"\n        Return the names of the frames in the model.\n\n        Returns:\n            The names of the frames in the model.\n        \"\"\"\n\n        return self.kin_dyn_parameters.frame_parameters.name\n\n\n# =====================\n# Model post-processing\n# =====================\n\n\ndef reduce(\n    model: JaxSimModel,\n    considered_joints: tuple[str, ...],\n    locked_joint_positions: dict[str, jtp.FloatLike] | None = None,\n) -> JaxSimModel:\n    \"\"\"\n    Reduce the model by lumping together the links 
connected by removed joints.\n\n    Args:\n        model: The model to reduce.\n        considered_joints: The sequence of joints to consider.\n        locked_joint_positions:\n            A dictionary containing the positions of the joints to be considered\n            in the reduction process. The removed joints in the reduced model\n            will have their position locked to their value of this dictionary.\n            If a joint is not part of the dictionary, its position is set to zero.\n    \"\"\"\n\n    locked_joint_positions = (\n        locked_joint_positions if locked_joint_positions is not None else {}\n    )\n\n    # If locked joints are passed, make sure that they are valid.\n    if not set(locked_joint_positions).issubset(model.joint_names()):\n        new_joints = set(model.joint_names()) - set(locked_joint_positions)\n        raise ValueError(f\"Passed joints not existing in the model: {new_joints}\")\n\n    # Operate on a deep copy of the model description in order to prevent problems\n    # when mutable attributes are updated.\n    intermediate_description = copy.deepcopy(model.description)\n\n    # Update the initial position of the joints.\n    # This is necessary to compute the correct pose of the link pairs connected\n    # to removed joints.\n    for joint_name in set(model.joint_names()) - set(considered_joints):\n        j = intermediate_description.joints_dict[joint_name]\n        with j.mutable_context():\n            j.initial_position = locked_joint_positions.get(joint_name, 0.0)\n\n    # Reduce the model description.\n    # If `considered_joints` contains joints not existing in the model,\n    # the method will raise an exception.\n    reduced_intermediate_description = intermediate_description.reduce(\n        considered_joints=list(considered_joints)\n    )\n\n    # Build the reduced model.\n    reduced_model = JaxSimModel.build(\n        model_description=reduced_intermediate_description,\n        model_name=model.name(),\n      
  time_step=model.time_step,\n        terrain=model.terrain,\n        contact_model=model.contact_model,\n        contact_params=model.contact_params,\n        actuation_params=model.actuation_params,\n        gravity=model.gravity,\n        integrator=model.integrator,\n        constraints=model.kin_dyn_parameters.constraints,\n    )\n\n    with reduced_model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):\n        # Store the origin of the model, in case downstream logic needs it.\n        reduced_model.built_from = model.built_from\n\n        # Compute the hw parametrization metadata of the reduced model\n        # TODO: move the building of the metadata to KinDynParameters.build()\n        #       and use the model_description instead of model.built_from.\n        reduced_model.kin_dyn_parameters.hw_link_metadata = (\n            reduced_model.compute_hw_link_metadata()\n        )\n\n    return reduced_model\n\n\n# ===================\n# Inertial properties\n# ===================\n\n\n@jax.jit\n@js.common.named_scope\ndef total_mass(model: JaxSimModel) -> jtp.Float:\n    \"\"\"\n    Compute the total mass of the model.\n\n    Args:\n        model: The model to consider.\n\n    Returns:\n        The total mass of the model.\n    \"\"\"\n\n    return model.kin_dyn_parameters.link_parameters.mass.sum().astype(float)\n\n\n@jax.jit\n@js.common.named_scope\ndef link_spatial_inertia_matrices(model: JaxSimModel) -> jtp.Array:\n    \"\"\"\n    Compute the spatial 6D inertia matrices of all links of the model.\n\n    Args:\n        model: The model to consider.\n\n    Returns:\n        A 3D array containing the stacked spatial 6D inertia matrices of the links.\n    \"\"\"\n\n    return jax.vmap(js.kin_dyn_parameters.LinkParameters.spatial_inertia)(\n        model.kin_dyn_parameters.link_parameters\n    )\n\n\n# ==============================\n# Rigid Body Dynamics Algorithms\n# ==============================\n\n\ndef _adjoint_from_rotation_translation(\n    
rotation: jtp.Matrix,\n    translation: jtp.Vector,\n) -> jtp.Matrix:\n    zeros = jnp.zeros_like(rotation)\n    top_right = jnp.einsum(\"...ij,...jk->...ik\", Skew.wedge(translation), rotation)\n    return jnp.concatenate(\n        [\n            jnp.concatenate([rotation, top_right], axis=-1),\n            jnp.concatenate([zeros, rotation], axis=-1),\n        ],\n        axis=-2,\n    )\n\n\ndef _inverse_adjoint_from_rotation_translation(\n    rotation: jtp.Matrix,\n    translation: jtp.Vector,\n) -> jtp.Matrix:\n    rotation_t = jnp.swapaxes(rotation, -1, -2)\n    zeros = jnp.zeros_like(rotation_t)\n    top_right = -jnp.einsum(\"...ij,...jk->...ik\", rotation_t, Skew.wedge(translation))\n    return jnp.concatenate(\n        [\n            jnp.concatenate([rotation_t, top_right], axis=-1),\n            jnp.concatenate([zeros, rotation_t], axis=-1),\n        ],\n        axis=-2,\n    )\n\n\ndef _apply_input_representation_to_jacobian(\n    jacobian: jtp.Matrix,\n    base_transform: jtp.Matrix,\n) -> jtp.Matrix:\n    transformed_base = jnp.einsum(\n        \"...ij,jk->...ik\",\n        jacobian[..., :, 0:6],\n        base_transform,\n    )\n    return jnp.concatenate([transformed_base, jacobian[..., :, 6:]], axis=-1)\n\n\ndef _apply_input_representation_derivative_to_jacobian(\n    jacobian: jtp.Matrix,\n    base_transform_derivative: jtp.Matrix,\n) -> jtp.Matrix:\n    transformed_base = jnp.einsum(\n        \"...ij,jk->...ik\",\n        jacobian[..., :, 0:6],\n        base_transform_derivative,\n    )\n    return jnp.concatenate(\n        [transformed_base, jnp.zeros_like(jacobian[..., :, 6:])],\n        axis=-1,\n    )\n\n\ndef _link_jacobian_support_mask(\n    model: JaxSimModel,\n    *,\n    dtype: jnp.dtype,\n) -> jtp.Matrix:\n    κb = model.kin_dyn_parameters.support_body_array_bool\n    return jnp.concatenate(\n        [\n            jnp.ones((model.number_of_links(), 5), dtype=dtype),\n            jnp.asarray(κb, dtype=dtype),\n        ],\n        axis=1,\n 
   )\n\n\ndef _body_input_transform(\n    data: js.data.JaxSimModelData,\n) -> tuple[jtp.Matrix, jtp.Matrix]:\n    base_transform = data._base_transform\n    base_rotation = base_transform[0:3, 0:3]\n\n    match data.velocity_representation:\n        case VelRepr.Inertial:\n            B_X_I = _inverse_adjoint_from_rotation_translation(\n                rotation=base_rotation,\n                translation=base_transform[0:3, 3],\n            )\n            B_Ẋ_I = -B_X_I @ Cross.vx(data.base_velocity)\n\n        case VelRepr.Body:\n            B_X_I = jnp.eye(6, dtype=base_transform.dtype)\n            B_Ẋ_I = jnp.zeros((6, 6), dtype=base_transform.dtype)\n\n        case VelRepr.Mixed:\n            B_X_I = _inverse_adjoint_from_rotation_translation(\n                rotation=base_rotation,\n                translation=jnp.zeros(3, dtype=base_transform.dtype),\n            )\n            BW_v_BW_B = data.base_velocity.at[0:3].set(\n                jnp.zeros(3, dtype=base_transform.dtype)\n            )\n            B_Ẋ_I = -B_X_I @ Cross.vx(BW_v_BW_B)\n\n        case _:\n            raise ValueError(data.velocity_representation)\n\n    return B_X_I, B_Ẋ_I\n\n\ndef _link_output_adjoint_from_body(\n    data: js.data.JaxSimModelData,\n    B_H_L: jtp.Matrix,\n    *,\n    output_vel_repr: VelRepr,\n) -> jtp.Matrix:\n    base_transform = data._base_transform\n    base_rotation = base_transform[0:3, 0:3]\n    B_R_L = B_H_L[..., 0:3, 0:3]\n    B_p_L = B_H_L[..., 0:3, 3]\n\n    match output_vel_repr:\n        case VelRepr.Inertial:\n            return _adjoint_from_rotation_translation(\n                rotation=base_rotation,\n                translation=base_transform[0:3, 3],\n            )\n\n        case VelRepr.Body:\n            return _inverse_adjoint_from_rotation_translation(\n                rotation=B_R_L,\n                translation=B_p_L,\n            )\n\n        case VelRepr.Mixed:\n            W_p_B_in_LW = -jnp.einsum(\"ij,...j->...i\", base_rotation, 
B_p_L)\n            W_R_B = jnp.broadcast_to(base_rotation, B_R_L.shape)\n            return _adjoint_from_rotation_translation(\n                rotation=W_R_B,\n                translation=W_p_B_in_LW,\n            )\n\n        case _:\n            raise ValueError(output_vel_repr)\n\n\n@functools.partial(jax.jit, static_argnames=[\"output_vel_repr\"])\ndef generalized_free_floating_jacobian(\n    model: JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    output_vel_repr: VelRepr | None = None,\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the free-floating jacobians of all links.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        output_vel_repr:\n            The output velocity representation of the free-floating jacobians.\n\n    Returns:\n        The `(nL, 6, 6+dofs)` array containing the stacked free-floating\n        jacobians of the links. The first axis is the link index.\n\n    Note:\n        The v-stacked version of the returned Jacobian array together with the\n        flattened 6D forces of the links, are useful to compute the `J.T @ f`\n        product of the multi-body EoM.\n    \"\"\"\n\n    output_vel_repr = (\n        output_vel_repr if output_vel_repr is not None else data.velocity_representation\n    )\n\n    # Compute the doubly-left free-floating full jacobian.\n    B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left(\n        model=model,\n        joint_positions=data.joint_positions,\n    )\n\n    support_mask = _link_jacobian_support_mask(model=model, dtype=B_J_full_WX_B.dtype)\n    B_J_WL_B = support_mask[:, jnp.newaxis, :] * B_J_full_WX_B[jnp.newaxis, ...]\n\n    B_X_I, _ = _body_input_transform(data=data)\n    B_J_WL_I = _apply_input_representation_to_jacobian(\n        jacobian=B_J_WL_B,\n        base_transform=B_X_I,\n    )\n\n    O_X_B = _link_output_adjoint_from_body(\n        data=data,\n        B_H_L=B_H_L,\n        output_vel_repr=output_vel_repr,\n    
)\n\n    return jnp.einsum(\"...ij,...jk->...ik\", O_X_B, B_J_WL_I)\n\n\n@functools.partial(jax.jit, static_argnames=[\"output_vel_repr\"])\ndef generalized_free_floating_jacobian_derivative(\n    model: JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    output_vel_repr: VelRepr | None = None,\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the free-floating jacobian derivatives of all links.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        output_vel_repr:\n            The output velocity representation of the free-floating jacobian derivatives.\n\n    Returns:\n        The `(nL, 6, 6+dofs)` array containing the stacked free-floating\n        jacobian derivatives of the links. The first axis is the link index.\n    \"\"\"\n\n    output_vel_repr = (\n        output_vel_repr if output_vel_repr is not None else data.velocity_representation\n    )\n\n    # Compute the derivative of the doubly-left free-floating full jacobian.\n    B_J̇_full_WX_B, B_H_L = jaxsim.rbda.jacobian_derivative_full_doubly_left(\n        model=model,\n        joint_positions=data.joint_positions,\n        joint_velocities=data.joint_velocities,\n    )\n\n    # The derivative of the equation to change the input and output representations\n    # of the Jacobian derivative needs the computation of the plain link Jacobian.\n    B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left(\n        model=model,\n        joint_positions=data.joint_positions,\n    )\n\n    support_mask = _link_jacobian_support_mask(model=model, dtype=B_J̇_full_WX_B.dtype)\n    B_J̇_WL_B = support_mask[:, jnp.newaxis, :] * B_J̇_full_WX_B[jnp.newaxis, ...]\n    B_J_WL_B = support_mask[:, jnp.newaxis, :] * B_J_full_WL_B[jnp.newaxis, ...]\n\n    B_X_I, B_Ẋ_I = _body_input_transform(data=data)\n    B_J_WL_I = _apply_input_representation_to_jacobian(\n        jacobian=B_J_WL_B,\n        base_transform=B_X_I,\n    )\n    B_J̇_WL_input = 
_apply_input_representation_to_jacobian(\n        jacobian=B_J̇_WL_B,\n        base_transform=B_X_I,\n    )\n    B_J̇_WL_repr = _apply_input_representation_derivative_to_jacobian(\n        jacobian=B_J_WL_B,\n        base_transform_derivative=B_Ẋ_I,\n    )\n\n    B_v_WB = B_X_I @ data.base_velocity\n    B_ν = jnp.concatenate([B_v_WB, data.joint_velocities])\n    B_v_WL = jnp.einsum(\"bij,j->bi\", B_J_WL_B, B_ν)\n\n    O_X_B = _link_output_adjoint_from_body(\n        data=data,\n        B_H_L=B_H_L,\n        output_vel_repr=output_vel_repr,\n    )\n\n    match output_vel_repr:\n        case VelRepr.Inertial:\n            O_Ẋ_B = O_X_B @ Cross.vx(B_v_WB)\n\n        case VelRepr.Body:\n            B_v_B_L = B_v_WL - B_v_WB\n            O_Ẋ_B = -jnp.einsum(\"...ij,...jk->...ik\", O_X_B, Cross.vx(B_v_B_L))\n\n        case VelRepr.Mixed:\n            base_rotation = data._base_transform[0:3, 0:3]\n            B_p_L = B_H_L[..., 0:3, 3]\n            W_p_B_in_LW = -jnp.einsum(\"ij,...j->...i\", base_rotation, B_p_L)\n            W_R_B = jnp.broadcast_to(base_rotation, B_H_L[..., 0:3, 0:3].shape)\n            B_X_LW = _inverse_adjoint_from_rotation_translation(\n                rotation=W_R_B,\n                translation=W_p_B_in_LW,\n            )\n\n            LW_v_WL = jnp.einsum(\"...ij,...j->...i\", O_X_B, B_v_WL)\n            LW_v_W_LW = LW_v_WL.at[..., 3:6].set(jnp.zeros_like(LW_v_WL[..., 3:6]))\n            LW_v_LW_L = LW_v_WL - LW_v_W_LW\n            LW_v_B_LW = LW_v_WL - jnp.einsum(\"...ij,j->...i\", O_X_B, B_v_WB) - LW_v_LW_L\n\n            O_Ẋ_B = -jnp.einsum(\n                \"...ij,...jk->...ik\",\n                O_X_B,\n                Cross.vx(jnp.einsum(\"...ij,...j->...i\", B_X_LW, LW_v_B_LW)),\n            )\n\n        case _:\n            raise ValueError(output_vel_repr)\n\n    O_J̇_WL_I = jnp.einsum(\"...ij,...jk->...ik\", O_Ẋ_B, B_J_WL_I)\n    O_J̇_WL_I += jnp.einsum(\"...ij,...jk->...ik\", O_X_B, B_J̇_WL_input)\n    O_J̇_WL_I += 
jnp.einsum(\"...ij,...jk->...ik\", O_X_B, B_J̇_WL_repr)\n\n    return O_J̇_WL_I\n\n\n@functools.partial(jax.jit, static_argnames=[\"prefer_aba\"])\ndef forward_dynamics(\n    model: JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    joint_forces: jtp.VectorLike | None = None,\n    link_forces: jtp.MatrixLike | None = None,\n    prefer_aba: float = True,\n) -> tuple[jtp.Vector, jtp.Vector]:\n    \"\"\"\n    Compute the forward dynamics of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        joint_forces:\n            The joint forces to consider as a vector of shape `(dofs,)`.\n        link_forces:\n            The link 6D forces consider as a matrix of shape `(nL, 6)`.\n            The frame in which they are expressed must be `data.velocity_representation`.\n        prefer_aba: Whether to prefer the ABA algorithm over the CRB one.\n\n    Returns:\n        A tuple containing the 6D acceleration in the active representation of the\n        base link and the joint accelerations resulting from the application of the\n        considered joint forces and external forces.\n    \"\"\"\n\n    forward_dynamics_fn = forward_dynamics_aba if prefer_aba else forward_dynamics_crb\n\n    return forward_dynamics_fn(\n        model=model,\n        data=data,\n        joint_forces=joint_forces,\n        link_forces=link_forces,\n    )\n\n\n@functools.partial(jax.jit, static_argnames=(\"parallel\",))\n@js.common.named_scope\ndef forward_dynamics_aba(\n    model: JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    joint_forces: jtp.VectorLike | None = None,\n    link_forces: jtp.MatrixLike | None = None,\n    parallel: bool = False,\n) -> tuple[jtp.Vector, jtp.Vector]:\n    \"\"\"\n    Compute the forward dynamics of the model with the ABA algorithm.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        joint_forces:\n            The joint 
forces to consider as a vector of shape `(dofs,)`.\n        link_forces:\n            The link 6D forces to consider as a matrix of shape `(nL, 6)`.\n            The frame in which they are expressed must be `data.velocity_representation`.\n        parallel:\n            If ``True``, use the level-parallel ABA implementation that\n            processes independent tree branches simultaneously.\n            Beneficial on GPU or for wide/deep kinematic trees.\n\n    Returns:\n        A tuple containing the 6D acceleration in the active representation of the\n        base link and the joint accelerations resulting from the application of the\n        considered joint forces and external forces.\n    \"\"\"\n\n    # ============\n    # Prepare data\n    # ============\n\n    # Build joint forces, if not provided.\n    τ = (\n        jnp.atleast_1d(joint_forces.squeeze())\n        if joint_forces is not None\n        else jnp.zeros_like(data.joint_positions)\n    )\n\n    # Build link forces, if not provided.\n    f_L = (\n        jnp.atleast_2d(link_forces.squeeze())\n        if link_forces is not None\n        else jnp.zeros((model.number_of_links(), 6))\n    )\n\n    # Create a references object that simplifies converting among representations.\n    references = js.references.JaxSimModelReferences.build(\n        model=model,\n        joint_force_references=τ,\n        link_forces=f_L,\n        data=data,\n        velocity_representation=data.velocity_representation,\n    )\n\n    # Extract the state in inertial-fixed representation.\n    with data.switch_velocity_representation(VelRepr.Inertial):\n        W_p_B = data.base_position\n        W_v_WB = data.base_velocity\n        W_Q_B = data.base_orientation\n        s = data.joint_positions\n        ṡ = data.joint_velocities\n\n    # Extract the inputs in inertial-fixed representation.\n    W_f_L = references._link_forces\n    τ = references._joint_force_references\n\n    # ========================\n    # Compute 
forward dynamics\n    # ========================\n\n    aba_fn = jaxsim.rbda.aba_parallel if parallel else jaxsim.rbda.aba\n\n    W_v̇_WB, s̈ = aba_fn(\n        model=model,\n        base_position=W_p_B,\n        base_quaternion=W_Q_B,\n        joint_positions=s,\n        base_linear_velocity=W_v_WB[0:3],\n        base_angular_velocity=W_v_WB[3:6],\n        joint_velocities=ṡ,\n        joint_transforms=model.kin_dyn_parameters.joint_transforms(\n            joint_positions=s,\n            base_transform=data.base_transform,\n        ),\n        joint_forces=τ,\n        link_forces=W_f_L,\n        standard_gravity=model.gravity,\n    )\n\n    # =============\n    # Adjust output\n    # =============\n\n    def to_active(\n        W_v̇_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WB: jtp.Vector, W_v_WC: jtp.Vector\n    ) -> jtp.Vector:\n        \"\"\"\n        Convert the inertial-fixed apparent base acceleration W_v̇_WB to\n        another representation C_v̇_WB expressed in a generic frame C.\n        \"\"\"\n\n        # In Mixed representation, we need to include a cross product in ℝ⁶.\n        # In Inertial and Body representations, the cross product is always zero.\n        C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)\n        return C_X_W @ (W_v̇_WB - Cross.vx(W_v_WC) @ W_v_WB)\n\n    match data.velocity_representation:\n        case VelRepr.Inertial:\n            # In this case C=W\n            W_H_C = W_H_W = jnp.eye(4)  # noqa: F841\n            W_v_WC = W_v_WW = jnp.zeros(6)  # noqa: F841\n\n        case VelRepr.Body:\n            # In this case C=B\n            W_H_C = W_H_B = data._base_transform\n            W_v_WC = W_v_WB\n\n        case VelRepr.Mixed:\n            # In this case C=B[W]\n            W_H_B = data._base_transform\n            W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))  # noqa: F841\n            W_ṗ_B = data.base_velocity[0:3]\n            W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)  # noqa: F841\n\n        
case _:\n            raise ValueError(data.velocity_representation)\n\n    # We need to convert the derivative of the base velocity to the active\n    # representation. In Mixed representation, this conversion is not a plain\n    # transformation with just X, but it also involves a cross product in ℝ⁶.\n    C_v̇_WB = to_active(\n        W_v̇_WB=W_v̇_WB,\n        W_H_C=W_H_C,\n        W_v_WB=W_v_WB,\n        W_v_WC=W_v_WC,\n    )\n\n    # The ABA algorithm already returns a zero base 6D acceleration for\n    # fixed-based models. However, the to_active function introduces an\n    # additional acceleration component in Mixed representation.\n    # Here below we make sure that the base acceleration is zero.\n    C_v̇_WB = C_v̇_WB if model.floating_base() else jnp.zeros(6)\n\n    return C_v̇_WB.astype(float), s̈.astype(float)\n\n\n@jax.jit\n@js.common.named_scope\ndef forward_dynamics_crb(\n    model: JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    joint_forces: jtp.VectorLike | None = None,\n    link_forces: jtp.MatrixLike | None = None,\n) -> tuple[jtp.Vector, jtp.Vector]:\n    \"\"\"\n    Compute the forward dynamics of the model with the CRB algorithm.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        joint_forces:\n            The joint forces to consider as a vector of shape `(dofs,)`.\n        link_forces:\n            The link 6D forces to consider as a matrix of shape `(nL, 6)`.\n            The frame in which they are expressed must be `data.velocity_representation`.\n\n    Returns:\n        A tuple containing the 6D acceleration in the active representation of the\n        base link and the joint accelerations resulting from the application of the\n        considered joint forces and external forces.\n\n    Note:\n        Compared to ABA, this method could be significantly slower, especially for\n        models with a large number of degrees of freedom.\n    \"\"\"\n\n    # 
============\n    # Prepare data\n    # ============\n\n    # Build joint torques if not provided.\n    τ = (\n        jnp.atleast_1d(joint_forces)\n        if joint_forces is not None\n        else jnp.zeros_like(data.joint_positions)\n    )\n\n    # Build external forces if not provided.\n    f = (\n        jnp.atleast_2d(link_forces)\n        if link_forces is not None\n        else jnp.zeros(shape=(model.number_of_links(), 6))\n    )\n\n    # Compute terms of the floating-base EoM.\n    M = free_floating_mass_matrix(model=model, data=data)\n    h = free_floating_bias_forces(model=model, data=data)\n    S = jnp.block([jnp.zeros(shape=(model.dofs(), 6)), jnp.eye(model.dofs())]).T\n    J = generalized_free_floating_jacobian(model=model, data=data)\n\n    # TODO: invert the Mss block exploiting sparsity defined by the parent array λ(i)\n\n    # ========================\n    # Compute forward dynamics\n    # ========================\n\n    if model.floating_base():\n        # l: number of links.\n        # g: generalized coordinates, 6 + number of joints.\n        JTf = jnp.einsum(\"l6g,l6->g\", J, f)\n        ν̇ = jnp.linalg.solve(M, S @ τ - h + JTf)\n\n    else:\n        # l: number of links.\n        # j: number of joints.\n        JTf = jnp.einsum(\"l6j,l6->j\", J[:, :, 6:], f)\n        s̈ = jnp.linalg.solve(M[6:, 6:], τ - h[6:] + JTf)\n\n        v̇_WB = jnp.zeros(6)\n        ν̇ = jnp.hstack([v̇_WB, s̈.squeeze()])\n\n    # =============\n    # Adjust output\n    # =============\n\n    # Extract the base acceleration in the active representation.\n    # Note that this is an apparent acceleration (relevant in Mixed representation),\n    # therefore it cannot be always expressed in different frames with just a\n    # 6D transformation X.\n    v̇_WB = ν̇[0:6].squeeze().astype(float)\n\n    # Extract the joint accelerations.\n    s̈ = jnp.atleast_1d(ν̇[6:].squeeze()).astype(float)\n\n    return v̇_WB, s̈\n\n\n@functools.partial(jax.jit, 
static_argnames=(\"parallel\",))\n@js.common.named_scope\ndef forward_kinematics(\n    model: JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    parallel: bool = False,\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the forward kinematics of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        parallel: If True, use the level-parallel FK implementation that\n            processes independent tree branches simultaneously.\n\n    Returns:\n        The nL x 4 x 4 array containing the stacked homogeneous transformations\n        of the links. The first axis is the link index.\n    \"\"\"\n\n    fk_fn = (\n        jaxsim.rbda.forward_kinematics_model_parallel\n        if parallel\n        else jaxsim.rbda.forward_kinematics_model\n    )\n\n    # Recompute joint transforms from the model to ensure gradients\n    # flow through model parameters.\n    joint_transforms = model.kin_dyn_parameters.joint_transforms(\n        joint_positions=data.joint_positions,\n        base_transform=data.base_transform,\n    )\n\n    W_H_LL, _ = fk_fn(\n        model=model,\n        base_position=data.base_position,\n        base_quaternion=data.base_quaternion,\n        joint_positions=data.joint_positions,\n        joint_velocities=data.joint_velocities,\n        base_linear_velocity_inertial=data._base_linear_velocity,\n        base_angular_velocity_inertial=data._base_angular_velocity,\n        joint_transforms=joint_transforms,\n    )\n\n    return W_H_LL\n\n\ndef _transform_M_block(M_body: jtp.Matrix, X: jtp.Matrix) -> jtp.Matrix:\n    \"\"\"\n    Apply invTᵀ M_body invT with invT = diag(X, I_n), without forming invT.\n\n    Args:\n        M_body: (6+n, 6+n) mass matrix (inverse) in body representation.\n        X:      (6, 6) adjoint (e.g. 
B_X_W or B_X_BW).\n\n    Returns:\n        M_repr: (6+n, 6+n) mass matrix (inverse) in the new representation.\n    \"\"\"\n\n    # invTᵀ M invT with invT = diag(X, I):\n    # Mbb' = Xᵀ Mbb X\n    # Mbj' = Xᵀ Mbj\n    # Mjb' = Mjb X\n    # Mjj' = Mjj\n    Mbb_t = X.T @ M_body[:6, :6] @ X\n    Mbj_t = X.T @ M_body[:6, 6:]\n    Mjb_t = M_body[6:, :6] @ X\n    Mjj_t = M_body[6:, 6:]\n\n    top = jnp.concatenate([Mbb_t, Mbj_t], axis=1)\n    bottom = jnp.concatenate([Mjb_t, Mjj_t], axis=1)\n    return jnp.concatenate([top, bottom], axis=0)\n\n\n@jax.jit\n@js.common.named_scope\ndef free_floating_mass_matrix(\n    model: JaxSimModel, data: js.data.JaxSimModelData\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the free-floating mass matrix of the model with the CRBA algorithm.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The free-floating mass matrix of the model.\n    \"\"\"\n\n    M_body = jaxsim.rbda.crba(\n        model=model,\n        joint_positions=data.joint_positions,\n    )\n\n    match data.velocity_representation:\n        case VelRepr.Body:\n            return M_body\n\n        case VelRepr.Inertial:\n            B_X_W = Adjoint.from_transform(transform=data.base_transform, inverse=True)\n\n            return _transform_M_block(M_body, B_X_W)\n\n        case VelRepr.Mixed:\n            BW_H_B = data.base_transform.at[0:3, 3].set(jnp.zeros(3))\n            B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)\n\n            return _transform_M_block(M_body, B_X_BW)\n        case _:\n            raise ValueError(data.velocity_representation)\n\n\n@jax.jit\n@js.common.named_scope\ndef free_floating_mass_matrix_inverse(\n    model: JaxSimModel, data: js.data.JaxSimModelData\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the inverse of the free-floating mass matrix of the model\n    with the CRBA algorithm.\n\n    Args:\n        model: The model to consider.\n        data: The data 
of the considered model.\n\n    Returns:\n        The inverse of the free-floating mass matrix of the model.\n    \"\"\"\n    M_inv_body = jaxsim.rbda.mass_inverse(\n        model=model,\n        joint_transforms=data._joint_transforms,\n    )\n\n    match data.velocity_representation:\n        case VelRepr.Body:\n            return M_inv_body\n        case VelRepr.Inertial:\n            W_X_B = Adjoint.from_transform(transform=data.base_transform)\n\n            return _transform_M_block(M_inv_body, W_X_B.T)\n        case VelRepr.Mixed:\n            B_H_BW = data.base_transform.at[0:3, 3].set(jnp.zeros(3))\n            BW_X_B = Adjoint.from_transform(transform=B_H_BW)\n\n            return _transform_M_block(M_inv_body, BW_X_B.T)\n        case _:\n            raise ValueError(data.velocity_representation)\n\n\n@jax.jit\n@js.common.named_scope\ndef free_floating_coriolis_matrix(\n    model: JaxSimModel, data: js.data.JaxSimModelData\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the free-floating Coriolis matrix of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The free-floating Coriolis matrix of the model.\n\n    Note:\n        This function, contrarily to other quantities of the equations of motion,\n        does not exploit any iterative algorithm. 
Therefore, the computation of\n        the Coriolis matrix may be much slower than other quantities.\n    \"\"\"\n\n    # We perform all the calculation in body-fixed.\n    # The Coriolis matrix computed in this representation is converted later\n    # to the active representation stored in data.\n    with data.switch_velocity_representation(VelRepr.Body):\n        B_ν = data.generalized_velocity\n\n        # Doubly-left free-floating Jacobian.\n        L_J_WL_B = generalized_free_floating_jacobian(model=model, data=data)\n\n        # Doubly-left free-floating Jacobian derivative.\n        L_J̇_WL_B = generalized_free_floating_jacobian_derivative(\n            model=model, data=data\n        )\n\n    L_M_L = link_spatial_inertia_matrices(model=model)\n\n    # Body-fixed link velocities.\n    # Note: we could have called link.velocity() instead of computing it ourselves,\n    # but since we need the link Jacobians later, we can save a double calculation.\n    L_v_WL = jax.vmap(lambda J: J @ B_ν)(L_J_WL_B)\n\n    # Compute the contribution of each link to the Coriolis matrix.\n    def compute_link_contribution(M, v, J, J̇) -> jtp.Array:\n        return J.T @ ((Cross.vx_star(v) @ M + M @ Cross.vx(v)) @ J + M @ J̇)\n\n    C_B_links = jax.vmap(compute_link_contribution)(\n        L_M_L,\n        L_v_WL,\n        L_J_WL_B,\n        L_J̇_WL_B,\n    )\n\n    # We need to adjust the Coriolis matrix for fixed-base models.\n    # In this case, the base link does not contribute to the matrix, and we need to zero\n    # the off-diagonal terms mapping joint quantities onto the base configuration.\n    if model.floating_base():\n        C_B = C_B_links.sum(axis=0)\n    else:\n        C_B = C_B_links[1:].sum(axis=0)\n        C_B = C_B.at[0:6, 6:].set(0.0)\n        C_B = C_B.at[6:, 0:6].set(0.0)\n\n    # Adjust the representation of the Coriolis matrix.\n    # Refer to https://github.com/traversaro/traversaro-phd-thesis, Section 3.6.\n    match data.velocity_representation:\n       
 case VelRepr.Body:\n            return C_B\n\n        case VelRepr.Inertial:\n            n = model.dofs()\n            W_H_B = data._base_transform\n            B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True)\n            B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n))\n\n            with data.switch_velocity_representation(VelRepr.Inertial):\n                W_v_WB = data.base_velocity\n                B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)\n\n            B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n)))\n\n            with data.switch_velocity_representation(VelRepr.Body):\n                M = free_floating_mass_matrix(model=model, data=data)\n\n            C = B_T_W.T @ (M @ B_Ṫ_W + C_B @ B_T_W)\n\n            return C\n\n        case VelRepr.Mixed:\n            n = model.dofs()\n            BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))\n            B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)\n            B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n))\n\n            with data.switch_velocity_representation(VelRepr.Mixed):\n                BW_v_WB = data.base_velocity\n                BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))\n\n            BW_v_BW_B = BW_v_WB - BW_v_W_BW\n            B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)\n\n            B_Ṫ_BW = jax.scipy.linalg.block_diag(B_Ẋ_BW, jnp.zeros(shape=(n, n)))\n\n            with data.switch_velocity_representation(VelRepr.Body):\n                M = free_floating_mass_matrix(model=model, data=data)\n\n            C = B_T_BW.T @ (M @ B_Ṫ_BW + C_B @ B_T_BW)\n\n            return C\n\n        case _:\n            raise ValueError(data.velocity_representation)\n\n\n@jax.jit\n@js.common.named_scope\ndef inverse_dynamics(\n    model: JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    joint_accelerations: jtp.VectorLike | None = None,\n    base_acceleration: jtp.VectorLike | None = 
None,\n    link_forces: jtp.MatrixLike | None = None,\n) -> tuple[jtp.Vector, jtp.Vector]:\n    \"\"\"\n    Compute inverse dynamics with the RNEA algorithm.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        joint_accelerations:\n            The joint accelerations to consider as a vector of shape `(dofs,)`.\n        base_acceleration:\n            The base acceleration to consider as a vector of shape `(6,)`.\n        link_forces:\n            The link 6D forces to consider as a matrix of shape `(nL, 6)`.\n            The frame in which they are expressed must be `data.velocity_representation`.\n\n    Returns:\n        A tuple containing the 6D force in the active representation applied to the\n        base to obtain the considered base acceleration, and the joint forces to apply\n        to obtain the considered joint accelerations.\n    \"\"\"\n\n    # ============\n    # Prepare data\n    # ============\n\n    # Build joint accelerations, if not provided.\n    s̈ = (\n        jnp.atleast_1d(jnp.array(joint_accelerations).squeeze())\n        if joint_accelerations is not None\n        else jnp.zeros_like(data.joint_positions)\n    )\n\n    # Build base acceleration, if not provided.\n    v̇_WB = (\n        jnp.array(base_acceleration).squeeze()\n        if base_acceleration is not None\n        else jnp.zeros(6)\n    )\n\n    # Build link forces, if not provided.\n    f_L = (\n        jnp.atleast_2d(jnp.array(link_forces).squeeze())\n        if link_forces is not None\n        else jnp.zeros(shape=(model.number_of_links(), 6))\n    )\n\n    def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC):\n        \"\"\"\n        Convert the active representation of the base acceleration C_v̇_WB\n        expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.\n        \"\"\"\n\n        W_X_C = Adjoint.from_transform(transform=W_H_C)\n        C_X_W = Adjoint.from_transform(transform=W_H_C, 
inverse=True)\n        C_v_WC = C_X_W @ W_v_WC\n\n        # In Mixed representation, we need to include a cross product in ℝ⁶.\n        # In Inertial and Body representations, the cross product is always zero.\n        return W_X_C @ (C_v̇_WB + Cross.vx(C_v_WC) @ C_v_WB)\n\n    match data.velocity_representation:\n        case VelRepr.Inertial:\n            W_H_C = W_H_W = jnp.eye(4)  # noqa: F841\n            W_v_WC = W_v_WW = jnp.zeros(6)  # noqa: F841\n\n        case VelRepr.Body:\n            W_H_C = W_H_B = data._base_transform\n            with data.switch_velocity_representation(VelRepr.Inertial):\n                W_v_WC = W_v_WB = data.base_velocity\n\n        case VelRepr.Mixed:\n            W_H_B = data._base_transform\n            W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))  # noqa: F841\n            W_ṗ_B = data.base_velocity[0:3]\n            W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)  # noqa: F841\n\n        case _:\n            raise ValueError(data.velocity_representation)\n\n    # We need to convert the derivative of the base acceleration to the Inertial\n    # representation. 
In Mixed representation, this conversion is not a plain\n    # transformation with just X, but it also involves a cross product in ℝ⁶.\n    W_v̇_WB = to_inertial(\n        C_v̇_WB=v̇_WB,\n        W_H_C=W_H_C,\n        C_v_WB=data.base_velocity,\n        W_v_WC=W_v_WC,\n    )\n\n    # Create a references object that simplifies converting among representations.\n    references = js.references.JaxSimModelReferences.build(\n        model=model,\n        data=data,\n        link_forces=f_L,\n        velocity_representation=data.velocity_representation,\n    )\n\n    # Extract the state in inertial-fixed representation.\n    with data.switch_velocity_representation(VelRepr.Inertial):\n        W_p_B = data.base_position\n        W_v_WB = data.base_velocity\n        W_Q_B = data.base_quaternion\n        s = data.joint_positions\n        ṡ = data.joint_velocities\n\n    # Extract the inputs in inertial-fixed representation.\n    W_f_L = references._link_forces\n\n    # ========================\n    # Compute inverse dynamics\n    # ========================\n\n    W_f_B, τ = jaxsim.rbda.rnea(\n        model=model,\n        base_position=W_p_B,\n        base_quaternion=W_Q_B,\n        joint_positions=s,\n        base_linear_velocity=W_v_WB[0:3],\n        base_angular_velocity=W_v_WB[3:6],\n        joint_velocities=ṡ,\n        base_linear_acceleration=W_v̇_WB[0:3],\n        base_angular_acceleration=W_v̇_WB[3:6],\n        joint_accelerations=s̈,\n        joint_transforms=data._joint_transforms,\n        link_forces=W_f_L,\n        standard_gravity=model.gravity,\n    )\n\n    # =============\n    # Adjust output\n    # =============\n\n    # Express W_f_B in the active representation.\n    f_B = js.data.JaxSimModelData.inertial_to_other_representation(\n        array=W_f_B,\n        other_representation=data.velocity_representation,\n        transform=data._base_transform,\n        is_force=True,\n    ).squeeze()\n\n    return f_B.astype(float), 
τ.astype(float)\n\n\n@jax.jit\n@js.common.named_scope\ndef free_floating_gravity_forces(\n    model: JaxSimModel, data: js.data.JaxSimModelData\n) -> jtp.Vector:\n    r\"\"\"\n    Compute the free-floating gravity forces :math:`g(\\mathbf{q})` of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The free-floating gravity forces of the model.\n    \"\"\"\n\n    # Build a new state with zeroed velocities.\n    data_rnea = js.data.JaxSimModelData.build(\n        model=model,\n        velocity_representation=data.velocity_representation,\n        base_position=data.base_position,\n        base_quaternion=data.base_quaternion,\n        joint_positions=data.joint_positions,\n    )\n\n    return jnp.hstack(\n        inverse_dynamics(\n            model=model,\n            data=data_rnea,\n            # Set zero inputs:\n            joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())),\n            base_acceleration=jnp.zeros(6),\n            link_forces=jnp.zeros(shape=(model.number_of_links(), 6)),\n        )\n    ).astype(float)\n\n\n@jax.jit\n@js.common.named_scope\ndef free_floating_bias_forces(\n    model: JaxSimModel, data: js.data.JaxSimModelData\n) -> jtp.Vector:\n    r\"\"\"\n    Compute the free-floating bias forces :math:`h(\\mathbf{q}, \\boldsymbol{\\nu})`\n    of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The free-floating bias forces of the model.\n    \"\"\"\n\n    # Set the generalized position and generalized velocity.\n    base_linear_velocity, base_angular_velocity = None, None\n    if model.floating_base():\n        base_velocity = data.base_velocity\n        base_linear_velocity = base_velocity[:3]\n        base_angular_velocity = base_velocity[3:]\n\n    data_rnea = js.data.JaxSimModelData.build(\n        model=model,\n        
velocity_representation=data.velocity_representation,\n        base_position=data.base_position,\n        base_quaternion=data.base_quaternion,\n        joint_positions=data.joint_positions,\n        joint_velocities=data.joint_velocities,\n        base_linear_velocity=base_linear_velocity,\n        base_angular_velocity=base_angular_velocity,\n    )\n\n    return jnp.hstack(\n        inverse_dynamics(\n            model=model,\n            data=data_rnea,\n            # Set zero inputs:\n            joint_accelerations=jnp.atleast_1d(jnp.zeros(model.dofs())),\n            base_acceleration=jnp.zeros(6),\n            link_forces=jnp.zeros(shape=(model.number_of_links(), 6)),\n        )\n    ).astype(float)\n\n\n# ==========================\n# Other kinematic quantities\n# ==========================\n\n\n@jax.jit\n@js.common.named_scope\ndef locked_spatial_inertia(\n    model: JaxSimModel, data: js.data.JaxSimModelData\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the locked 6D inertia matrix of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The locked 6D inertia matrix of the model.\n    \"\"\"\n\n    return total_momentum_jacobian(model=model, data=data)[:, 0:6]\n\n\n@jax.jit\n@js.common.named_scope\ndef total_momentum(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:\n    \"\"\"\n    Compute the total momentum of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The total momentum of the model in the active velocity representation.\n    \"\"\"\n\n    ν = data.generalized_velocity\n    Jh = total_momentum_jacobian(model=model, data=data)\n\n    return Jh @ ν\n\n\n@functools.partial(jax.jit, static_argnames=[\"output_vel_repr\"])\ndef total_momentum_jacobian(\n    model: JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    output_vel_repr: VelRepr | None = None,\n) -> 
jtp.Matrix:\n    \"\"\"\n    Compute the jacobian of the total momentum.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        output_vel_repr: The output velocity representation of the jacobian.\n\n    Returns:\n        The jacobian of the total momentum of the model in the active representation.\n    \"\"\"\n\n    output_vel_repr = (\n        output_vel_repr if output_vel_repr is not None else data.velocity_representation\n    )\n\n    if output_vel_repr is data.velocity_representation:\n        return free_floating_mass_matrix(model=model, data=data)[0:6]\n\n    with data.switch_velocity_representation(VelRepr.Body):\n        B_Jh_B = free_floating_mass_matrix(model=model, data=data)[0:6]\n\n    match data.velocity_representation:\n        case VelRepr.Body:\n            B_Jh = B_Jh_B\n\n        case VelRepr.Inertial:\n            B_X_W = Adjoint.from_transform(transform=data._base_transform, inverse=True)\n            B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs()))\n\n        case VelRepr.Mixed:\n            BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))\n            B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)\n            B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))\n\n        case _:\n            raise ValueError(data.velocity_representation)\n\n    match output_vel_repr:\n        case VelRepr.Body:\n            return B_Jh\n\n        case VelRepr.Inertial:\n            W_H_B = data._base_transform\n            B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True)\n            W_Xf_B = B_Xv_W.T\n            W_Jh = W_Xf_B @ B_Jh\n            return W_Jh\n\n        case VelRepr.Mixed:\n            BW_H_B = data._base_transform.at[0:3, 3].set(jnp.zeros(3))\n            B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)\n            BW_Xf_B = B_Xv_BW.T\n            BW_Jh = BW_Xf_B @ B_Jh\n           
 return BW_Jh\n\n        case _:\n            raise ValueError(output_vel_repr)\n\n\n@jax.jit\n@js.common.named_scope\ndef average_velocity(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Vector:\n    \"\"\"\n    Compute the average velocity of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The average velocity of the model computed in the base frame and expressed\n        in the active representation.\n    \"\"\"\n\n    ν = data.generalized_velocity\n    J = average_velocity_jacobian(model=model, data=data)\n\n    return J @ ν\n\n\n@functools.partial(jax.jit, static_argnames=[\"output_vel_repr\"])\ndef average_velocity_jacobian(\n    model: JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    output_vel_repr: VelRepr | None = None,\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the Jacobian of the average velocity of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        output_vel_repr: The output velocity representation of the jacobian.\n\n    Returns:\n        The Jacobian of the average centroidal velocity of the model in the desired\n        representation.\n    \"\"\"\n\n    output_vel_repr = (\n        output_vel_repr if output_vel_repr is not None else data.velocity_representation\n    )\n\n    # Depending on the velocity representation, the frame G is either G[W] or G[B].\n    G_J = js.com.average_centroidal_velocity_jacobian(model=model, data=data)\n\n    match output_vel_repr:\n        case VelRepr.Inertial:\n            GW_J = G_J\n            W_p_CoM = js.com.com_position(model=model, data=data)\n\n            W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)\n            W_X_GW = Adjoint.from_transform(transform=W_H_GW)\n\n            return W_X_GW @ GW_J\n\n        case VelRepr.Body:\n            GB_J = G_J\n            W_p_B = data.base_position\n            W_p_CoM = 
js.com.com_position(model=model, data=data)\n            B_R_W = jaxsim.math.Quaternion.to_dcm(data.base_orientation).transpose()\n\n            B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B))\n            B_X_GB = Adjoint.from_transform(transform=B_H_GB)\n\n            return B_X_GB @ GB_J\n\n        case VelRepr.Mixed:\n            GW_J = G_J\n            W_p_B = data.base_position\n            W_p_CoM = js.com.com_position(model=model, data=data)\n\n            BW_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM - W_p_B)\n            BW_X_GW = Adjoint.from_transform(transform=BW_H_GW)\n\n            return BW_X_GW @ GW_J\n\n\n# ========================\n# Other dynamic quantities\n# ========================\n\n\n@jax.jit\n@js.common.named_scope\ndef link_bias_accelerations(\n    model: JaxSimModel,\n    data: js.data.JaxSimModelData,\n) -> jtp.Vector:\n    r\"\"\"\n    Compute the bias accelerations of the links of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The bias accelerations of the links of the model.\n\n    Note:\n        This function computes the component of the total 6D acceleration not due to\n        the joint or base acceleration.\n        It is often called :math:`\\dot{J} \\boldsymbol{\\nu}`.\n    \"\"\"\n\n    # ================================================\n    # Compute the body-fixed zero base 6D acceleration\n    # ================================================\n\n    # Compute the base transform.\n    W_H_B = data._base_transform\n\n    def other_representation_to_inertial(\n        C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector\n    ) -> jtp.Vector:\n        \"\"\"\n        Convert the active representation of the base acceleration C_v̇_WB\n        expressed in a generic frame C to the inertial-fixed representation W_v̇_WB.\n        \"\"\"\n\n        W_X_C = Adjoint.from_transform(transform=W_H_C)\n        
C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)\n\n        # In Mixed representation, we need to include a cross product in ℝ⁶.\n        # In Inertial and Body representations, the cross product is always zero.\n        return W_X_C @ (C_v̇_WB + jaxsim.math.Cross.vx(C_X_W @ W_v_WC) @ C_v_WB)\n\n    # Here we initialize a zero 6D acceleration in the active representation, and\n    # convert it to inertial-fixed. This is a useful intermediate representation\n    # because the apparent acceleration W_v̇_WB is equal to the intrinsic acceleration\n    # W_a_WB, and intrinsic accelerations can be expressed in different frames through\n    # a simple C_X_W 6D transform.\n    match data.velocity_representation:\n        case VelRepr.Inertial:\n            W_H_C = W_H_W = jnp.eye(4)  # noqa: F841\n            W_v_WC = W_v_WW = jnp.zeros(6)  # noqa: F841\n            with data.switch_velocity_representation(VelRepr.Inertial):\n                C_v_WB = W_v_WB = data.base_velocity\n\n        case VelRepr.Body:\n            W_H_C = W_H_B\n            with data.switch_velocity_representation(VelRepr.Inertial):\n                W_v_WC = W_v_WB = data.base_velocity  # noqa: F841\n            with data.switch_velocity_representation(VelRepr.Body):\n                C_v_WB = B_v_WB = data.base_velocity\n\n        case VelRepr.Mixed:\n            W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))\n            W_H_C = W_H_BW\n            with data.switch_velocity_representation(VelRepr.Mixed):\n                W_ṗ_B = data.base_velocity[0:3]\n                BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B)\n                W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW)\n                W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW  # noqa: F841\n            with data.switch_velocity_representation(VelRepr.Mixed):\n                C_v_WB = BW_v_WB = data.base_velocity  # noqa: F841\n\n        case _:\n            raise ValueError(data.velocity_representation)\n\n    # 
Convert a zero 6D acceleration from the active representation to inertial-fixed.\n    W_v̇_WB = other_representation_to_inertial(\n        C_v̇_WB=jnp.zeros(6), C_v_WB=C_v_WB, W_H_C=W_H_C, W_v_WC=W_v_WC\n    )\n\n    # ===================================\n    # Initialize buffers and prepare data\n    # ===================================\n\n    # Get the parent array λ(i).\n    # Note: λ(0) must not be used, it's initialized to -1.\n    λ = model.kin_dyn_parameters.parent_array\n\n    # Compute 6D transforms of the base velocity.\n    B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)\n\n    # Compute the parent-to-child adjoints and the motion subspaces of the joints.\n    # These transforms define the relative kinematics of the entire model, including\n    # the base transform for both floating-base and fixed-base models.\n    # Ensure cached transforms stay on device when indexed with traced `i`.\n    i_X_λi = jnp.asarray(data._joint_transforms)\n\n    # Extract the joint motion subspaces.\n    S = model.kin_dyn_parameters.motion_subspaces\n\n    # Allocate the buffer to store the body-fixed link velocities.\n    L_v_WL = jnp.zeros(shape=(model.number_of_links(), 6))\n\n    # Store the base velocity.\n    with data.switch_velocity_representation(VelRepr.Body):\n        B_v_WB = data.base_velocity\n        L_v_WL = L_v_WL.at[0].set(B_v_WB)\n\n    # Get the joint velocities.\n    ṡ = data.joint_velocities\n\n    # Allocate the buffer to store the body-fixed link accelerations,\n    # and initialize the base acceleration.\n    L_v̇_WL = jnp.zeros(shape=(model.number_of_links(), 6))\n    L_v̇_WL = L_v̇_WL.at[0].set(B_X_W @ W_v̇_WB)\n\n    # ======================================\n    # Propagate accelerations and velocities\n    # ======================================\n\n    # The computation of the bias forces is similar to the forward pass of RNEA,\n    # this time with zero base and joint accelerations. 
Furthermore, here we do\n    # not remove gravity during the propagation.\n\n    # Initialize the loop.\n    Carry = tuple[jtp.Matrix, jtp.Matrix]\n    carry0: Carry = (L_v_WL, L_v̇_WL)\n\n    def propagate_accelerations(carry: Carry, i: jtp.Int) -> tuple[Carry, None]:\n        # Initialize index and unpack the carry.\n        ii = i - 1\n        v, a = carry\n\n        # Get the motion subspace of the joint.\n        Si = S[i].squeeze()\n\n        # Project the joint velocity into its motion subspace.\n        vJ = Si * ṡ[ii]\n\n        # Propagate the link body-fixed velocity.\n        v_i = i_X_λi[i] @ v[λ[i]] + vJ\n        v = v.at[i].set(v_i)\n\n        # Propagate the link body-fixed acceleration considering zero joint acceleration.\n        s̈ = 0.0\n        a_i = i_X_λi[i] @ a[λ[i]] + Si * s̈ + jaxsim.math.Cross.vx(v[i]) @ vJ\n        a = a.at[i].set(a_i)\n\n        return (v, a), None\n\n    # Compute the body-fixed velocity and body-fixed apparent acceleration of the links.\n    (L_v_WL, L_v̇_WL), _ = (\n        jax.lax.scan(\n            f=propagate_accelerations,\n            init=carry0,\n            xs=jnp.arange(start=1, stop=model.number_of_links()),\n        )\n        if model.number_of_links() > 1\n        else [(L_v_WL, L_v̇_WL), None]\n    )\n\n    # ===================================================================\n    # Convert the body-fixed 6D acceleration to the active representation\n    # ===================================================================\n\n    def body_to_other_representation(\n        L_v̇_WL: jtp.Vector, L_v_WL: jtp.Vector, C_H_L: jtp.Matrix, L_v_CL: jtp.Vector\n    ) -> jtp.Vector:\n        \"\"\"\n        Convert the body-fixed apparent acceleration L_v̇_WL to\n        another representation C_v̇_WL expressed in a generic frame C.\n        \"\"\"\n\n        # In Mixed representation, we need to include a cross product in ℝ⁶.\n        # In Inertial and Body representations, the cross product is always zero.\n     
   C_X_L = jaxsim.math.Adjoint.from_transform(transform=C_H_L)\n        return C_X_L @ (L_v̇_WL + jaxsim.math.Cross.vx(L_v_CL) @ L_v_WL)\n\n    match data.velocity_representation:\n        case VelRepr.Body:\n            C_H_L = L_H_L = jnp.stack(  # noqa: F841\n                [jnp.eye(4)] * model.number_of_links()\n            )\n            L_v_CL = L_v_LL = jnp.zeros(  # noqa: F841\n                shape=(model.number_of_links(), 6)\n            )\n\n        case VelRepr.Inertial:\n            C_H_L = W_H_L = data._link_transforms\n            L_v_CL = L_v_WL\n\n        case VelRepr.Mixed:\n            W_H_L = data._link_transforms\n            LW_H_L = jax.vmap(lambda W_H_L: W_H_L.at[0:3, 3].set(jnp.zeros(3)))(W_H_L)\n            C_H_L = LW_H_L\n            L_v_CL = L_v_LW_L = jax.vmap(  # noqa: F841\n                lambda v: v.at[0:3].set(jnp.zeros(3))\n            )(L_v_WL)\n\n        case _:\n            raise ValueError(data.velocity_representation)\n\n    # Convert from body-fixed to the active representation.\n    O_v̇_WL = jax.vmap(body_to_other_representation)(\n        L_v̇_WL=L_v̇_WL, L_v_WL=L_v_WL, C_H_L=C_H_L, L_v_CL=L_v_CL\n    )\n\n    return O_v̇_WL\n\n\n@jax.jit\ndef joint_transforms(\n    model: JaxSimModel, joint_positions: jtp.VectorLike, base_transform: jtp.MatrixLike\n) -> jtp.Array:\n    r\"\"\"\n    Return the transforms of the joints.\n\n    Args:\n        model: The model to consider.\n        joint_positions: The joint positions.\n        base_transform: The homogeneous matrix defining the base pose.\n\n    Returns:\n        The stacked transforms\n        :math:`{}^{i} \\mathbf{H}_{\\lambda(i)}(s)`\n        of each joint.\n    \"\"\"\n\n    return model.kin_dyn_parameters.joint_transforms(\n        joint_positions=joint_positions,\n        base_transform=base_transform,\n    )\n\n\n# ======\n# Energy\n# ======\n\n\n@jax.jit\n@js.common.named_scope\ndef mechanical_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> 
jtp.Float:\n    \"\"\"\n    Compute the mechanical energy of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The mechanical energy of the model.\n    \"\"\"\n\n    K = kinetic_energy(model=model, data=data)\n    U = potential_energy(model=model, data=data)\n\n    return (K + U).astype(float)\n\n\n@jax.jit\n@js.common.named_scope\ndef kinetic_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:\n    \"\"\"\n    Compute the kinetic energy of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The kinetic energy of the model.\n    \"\"\"\n\n    with data.switch_velocity_representation(velocity_representation=VelRepr.Body):\n        B_ν = data.generalized_velocity\n        M_B = free_floating_mass_matrix(model=model, data=data)\n\n    K = 0.5 * B_ν.T @ M_B @ B_ν\n    return K.squeeze().astype(float)\n\n\n@jax.jit\n@js.common.named_scope\ndef potential_energy(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp.Float:\n    \"\"\"\n    Compute the potential energy of the model.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n\n    Returns:\n        The potential energy of the model.\n    \"\"\"\n\n    m = total_mass(model=model)\n    W_p̃_CoM = jnp.hstack([js.com.com_position(model=model, data=data), 1])\n    return jnp.sum((m * W_p̃_CoM)[2] * model.gravity)\n\n\n# ===================\n# Hw parametrization\n# ===================\n\n\n@jax.jit\n@js.common.named_scope\ndef update_hw_parameters(\n    model: JaxSimModel, scaling_factors: ScalingFactors\n) -> JaxSimModel:\n    \"\"\"\n    Update the hardware parameters of the model by scaling the parameters of the links.\n\n    This function applies scaling factors to the hardware metadata of the links,\n    updating their shape, dimensions, density, and other related 
parameters. It\n    recalculates the mass and inertia tensors of the links based on the updated\n    metadata and adjusts the joint model transforms accordingly.\n\n    Args:\n        model: The JaxSimModel object to update.\n        scaling_factors: A ScalingFactors object containing scaling factors for\n                         dimensions and density of the links.\n\n    Returns:\n        The updated JaxSimModel object with modified hardware parameters.\n    \"\"\"\n\n    kin_dyn_params: KinDynParameters = model.kin_dyn_parameters\n    link_parameters: LinkParameters = kin_dyn_params.link_parameters\n    hw_link_metadata: HwLinkMetadata = kin_dyn_params.hw_link_metadata\n\n    has_joints = model.number_of_joints() > 0\n\n    def apply_scaling_single_link(\n        link_shape,\n        geometry,\n        density,\n        L_H_G,\n        L_H_vis,\n        L_H_pre,\n        L_H_pre_mask,\n        scaling_dims,\n        scaling_density,\n    ):\n        \"\"\"Apply scaling to a single link's numerical data.\"\"\"\n\n        def scale_supported(_):\n            shape_indices_map = jnp.array([[0, 1, 2], [0, 0, 1], [0, 0, 0], [0, 1, 2]])\n            per_link_indices = shape_indices_map[link_shape]\n            scale_vector = scaling_dims[per_link_indices]\n\n            # Update kinematics\n            G_H_L = jaxsim.math.Transform.inverse(L_H_G)\n            G_H_vis = G_H_L @ L_H_vis\n            G_H̅_vis = G_H_vis.at[:3, 3].set(scale_vector * G_H_vis[:3, 3])\n            L_H̅_G = L_H_G.at[:3, 3].set(scale_vector * L_H_G[:3, 3])\n            L_H̅_vis = L_H̅_G @ G_H̅_vis\n\n            # Update shape parameters\n            updated_geom = geometry * scaling_dims\n            updated_dens = density * scaling_density\n\n            return updated_geom, updated_dens, L_H̅_G, L_H̅_vis, scale_vector\n\n        def scale_unsupported(_):\n            return (\n                geometry,\n                density,\n                L_H_G,\n                L_H_vis,\n              
  jnp.ones_like(scaling_dims),\n            )\n\n        return jax.lax.cond(\n            link_shape == LinkParametrizableShape.Unsupported,\n            scale_unsupported,\n            scale_supported,\n            operand=None,\n        )\n\n    # Vmap over all links for basic scaling\n    (\n        updated_geometry,\n        updated_density,\n        updated_L_H_G,\n        updated_L_H_vis,\n        scale_vectors,\n    ) = jax.vmap(apply_scaling_single_link)(\n        hw_link_metadata.link_shape,\n        hw_link_metadata.geometry,\n        hw_link_metadata.density,\n        hw_link_metadata.L_H_G,\n        hw_link_metadata.L_H_vis,\n        hw_link_metadata.L_H_pre,\n        hw_link_metadata.L_H_pre_mask,\n        scaling_factors.dims,\n        scaling_factors.density,\n    )\n\n    # Handle joint transforms separately, only if model has joints\n    def transform_all_joints(operands):\n        \"\"\"Transform all joint poses across all links.\"\"\"\n        original_L_H_G, updated_L_H_G, scale_vectors, L_H_pre, L_H_pre_mask = operands\n\n        # Vectorized transformation: (n_links, n_joints, 4, 4)\n        # Express joint transforms in the original CoM frames.\n        # Using the already-scaled L_H_G here introduces a second implicit\n        # scaling term and distorts kinematic chain proportions.\n        G_H_L_all = jax.vmap(jaxsim.math.Transform.inverse)(\n            original_L_H_G\n        )  # (n_links, 4, 4)\n\n        # Use batch matrix multiply with broadcasting\n        # G_H_L_all: (n_links, 4, 4) -> (n_links, 1, 4, 4)\n        # L_H_pre: (n_links, n_joints, 4, 4)\n        # Result: (n_links, n_joints, 4, 4)\n        G_H_pre = G_H_L_all[:, None, :, :] @ L_H_pre\n\n        # Scale translation components\n        G_H̅_pre = G_H_pre.at[:, :, :3, 3].set(\n            jnp.where(\n                L_H_pre_mask[:, :, None],\n                scale_vectors[:, None, :] * G_H_pre[:, :, :3, 3],\n                G_H_pre[:, :, :3, 3],\n            )\n        
)\n\n        # Transform back to link frames\n        # updated_L_H_G: (n_links, 4, 4) -> (n_links, 1, 4, 4)\n        # G_H̅_pre: (n_links, n_joints, 4, 4)\n        # Result: (n_links, n_joints, 4, 4)\n        return updated_L_H_G[:, None, :, :] @ G_H̅_pre\n\n    updated_L_H_pre = jax.lax.cond(\n        has_joints,\n        transform_all_joints,\n        lambda operands: operands[3],  # Return L_H_pre unchanged\n        operand=(\n            hw_link_metadata.L_H_G,\n            updated_L_H_G,\n            scale_vectors,\n            hw_link_metadata.L_H_pre,\n            hw_link_metadata.L_H_pre_mask,\n        ),\n    )\n\n    # Create updated HwLinkMetadata\n    updated_hw_link_metadata = hw_link_metadata.replace(\n        geometry=updated_geometry,\n        density=updated_density,\n        L_H_G=updated_L_H_G,\n        L_H_vis=updated_L_H_vis,\n        L_H_pre=updated_L_H_pre,\n    )\n\n    # Compute mass and inertia once and unpack the results\n    m_updated, I_com_updated = HwLinkMetadata.compute_mass_and_inertia(\n        updated_hw_link_metadata\n    )\n\n    # Rotate the inertia tensor at CoM with the link orientation, and store\n    # it in KynDynParameters.\n    I_L_updated = jax.vmap(\n        lambda metadata, I_com: metadata.L_H_G[:3, :3]\n        @ I_com\n        @ metadata.L_H_G[:3, :3].T\n    )(updated_hw_link_metadata, I_com_updated)\n\n    # Update link parameters\n    updated_link_parameters = link_parameters.replace(\n        mass=m_updated,\n        inertia_elements=jax.vmap(LinkParameters.flatten_inertia_tensor)(I_L_updated),\n        center_of_mass=jax.vmap(lambda metadata: metadata.L_H_G[:3, 3])(\n            updated_hw_link_metadata\n        ),\n    )\n\n    if kin_dyn_params.contact_parameters.body:\n        # Compute the contact parameters\n        points = HwLinkMetadata.compute_contact_points(\n            original_contact_params=kin_dyn_params.contact_parameters,\n            link_shapes=updated_hw_link_metadata.link_shape,\n           
 original_com_positions=link_parameters.center_of_mass,\n            updated_com_positions=updated_link_parameters.center_of_mass,\n            scaling_factors=scaling_factors,\n        )\n\n        # Update contact parameters\n        updated_contact_parameters = kin_dyn_params.contact_parameters.replace(\n            point=points\n        )\n    else:\n        updated_contact_parameters = kin_dyn_params.contact_parameters\n\n    # Update joint model transforms (λ_H_pre)\n    def update_λ_H_pre(joint_index):\n        # Extract the transforms and masks for the current joint index across all links\n        L_H_pre_for_joint = updated_hw_link_metadata.L_H_pre[:, joint_index]\n        L_H_pre_mask_for_joint = updated_hw_link_metadata.L_H_pre_mask[:, joint_index]\n\n        # Select the first valid transform (if any) using the mask\n        first_valid_index = jnp.argmax(L_H_pre_mask_for_joint)\n        selected_transform = L_H_pre_for_joint[first_valid_index]\n\n        # Check if any valid transform exists\n        has_valid_transform = L_H_pre_mask_for_joint.any()\n\n        # Fallback to the original λ_H_pre if no valid transform exists\n        fallback_transform = kin_dyn_params.joint_model.λ_H_pre[joint_index + 1]\n\n        # Return the selected transform or fallback\n        return jnp.where(has_valid_transform, selected_transform, fallback_transform)\n\n    if has_joints:\n        # Apply the update function to all joint indices\n        updated_λ_H_pre = jax.vmap(update_λ_H_pre)(\n            jnp.arange(kin_dyn_params.number_of_joints())\n        )\n\n        # NOTE: λ_H_pre should be of len (1+n_joints) with the 0-th element equal\n        # to identity to represent the world-to-base tree transform. 
See JointModel class\n        updated_λ_H_pre_with_base = jnp.concatenate(\n            (jnp.eye(4).reshape(1, 4, 4), updated_λ_H_pre), axis=0\n        )\n\n        # Replace the joint model with the updated transforms\n        updated_joint_model = kin_dyn_params.joint_model.replace(\n            λ_H_pre=updated_λ_H_pre_with_base\n        )\n\n    else:\n        # If there are no joints, we can just use the identity transform\n        updated_joint_model = kin_dyn_params.joint_model\n\n    # Replace the kin_dyn_parameters with updated values\n    updated_kin_dyn_params = kin_dyn_params.replace(\n        link_parameters=updated_link_parameters,\n        contact_parameters=updated_contact_parameters,\n        hw_link_metadata=updated_hw_link_metadata,\n        joint_model=updated_joint_model,\n    )\n\n    # Return the updated model\n    return model.replace(kin_dyn_parameters=updated_kin_dyn_params)\n\n\n# ==========\n# Simulation\n# ==========\n\n\n@jax.jit\n@js.common.named_scope\ndef step(\n    model: JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    link_forces: jtp.MatrixLike | None = None,\n    joint_force_references: jtp.VectorLike | None = None,\n) -> js.data.JaxSimModelData:\n    \"\"\"\n    Perform a simulation step.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        dt: The time step to consider. If not specified, it is read from the model.\n        link_forces:\n            The 6D forces to apply to the links expressed in same representation of data.\n        joint_force_references: The joint force references to consider.\n\n    Returns:\n        The new data of the model after the simulation step.\n\n    Note:\n        In order to reduce the occurrences of frame conversions performed internally,\n        it is recommended to use inertial-fixed velocity representation. 
This can be\n        particularly useful for automatically differentiated logic.\n    \"\"\"\n\n    # TODO: some contact models here may want to perform a dynamic filtering of\n    # the enabled collidable points\n\n    # Extract the inputs\n    O_f_L_external = jnp.atleast_2d(\n        jnp.array(link_forces, dtype=float).squeeze()\n        if link_forces is not None\n        else jnp.zeros((model.number_of_links(), 6))\n    )\n\n    # Get the external forces in inertial-fixed representation.\n    W_f_L_external = js.data.JaxSimModelData.other_representation_to_inertial(\n        O_f_L_external,\n        other_representation=data.velocity_representation,\n        transform=data._link_transforms,\n        is_force=True,\n    )\n\n    τ_references = jnp.atleast_1d(\n        jnp.array(joint_force_references, dtype=float).squeeze()\n        if joint_force_references is not None\n        else jnp.zeros(model.dofs())\n    )\n\n    # ================================\n    # Compute the total joint torques\n    # ================================\n\n    τ_total = js.actuation_model.compute_resultant_torques(\n        model, data, joint_force_references=τ_references\n    )\n\n    # =============================\n    # Advance the simulation state\n    # =============================\n\n    from .integrators import _INTEGRATORS_MAP\n\n    integrator_fn = _INTEGRATORS_MAP[model.integrator]\n\n    data_tf = integrator_fn(\n        model=model,\n        data=data,\n        link_forces=W_f_L_external,\n        joint_torques=τ_total,\n    )\n\n    data_tf = model.contact_model.update_velocity_after_impact(\n        model=model, data=data_tf\n    )\n\n    return data_tf\n"
  },
  {
    "path": "src/jaxsim/api/ode.py",
    "content": "import jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.math import Quaternion, Skew\nfrom jaxsim.rbda.kinematic_constraints import compute_constraint_wrenches\n\nfrom .common import VelRepr\n\n# ==================================\n# Functions defining system dynamics\n# ==================================\n\n\ndef system_acceleration(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    link_forces: jtp.MatrixLike | None = None,\n    joint_torques: jtp.VectorLike | None = None,\n) -> tuple[jtp.Vector, jtp.Vector, dict[str, jtp.PyTree]]:\n    \"\"\"\n    Compute the system acceleration in the active representation.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        link_forces:\n            The 6D forces to apply to the links expressed in the same\n            velocity representation of data.\n        joint_torques: The joint torques applied to the joints.\n\n    Returns:\n        A tuple containing the base 6D acceleration in the active representation,\n        the joint accelerations, and the contact state.\n    \"\"\"\n\n    # ====================\n    # Validate input data\n    # ====================\n\n    # Build link forces if not provided.\n    f_L = (\n        jnp.atleast_2d(link_forces.squeeze())\n        if link_forces is not None\n        else jnp.zeros((model.number_of_links(), 6))\n    ).astype(float)\n\n    # ======================\n    # Compute contact forces\n    # ======================\n\n    W_f_L_terrain = jnp.zeros_like(f_L)\n    contact_state = data.contact_state\n\n    if len(model.kin_dyn_parameters.contact_parameters.body) > 0:\n\n        # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact\n        # with the terrain.\n        W_f_L_terrain, contact_state_derivative = js.contact.link_contact_forces(\n            model=model,\n            data=data,\n            
link_forces=f_L,\n            joint_torques=joint_torques,\n        )\n\n        # Update the contact state data. This is necessary only for the contact models\n        # that require propagation and integration of contact state.\n        contact_state = model.contact_model.update_contact_state(\n            contact_state_derivative\n        )\n\n    # ==================================\n    # Compute kinematic constraint forces\n    # ==================================\n\n    # Sum up all the forces: external + contact\n    W_f_L_total = f_L + W_f_L_terrain\n\n    # Compute the 6D forces W_f ∈ ℝ^{n_constraints × 2 × 6} applied to links due to\n    # kinematic constraints.\n    W_f_L_constraints = compute_constraint_wrenches(\n        model=model,\n        data=data,\n        link_forces_inertial=W_f_L_total,\n        joint_force_references=joint_torques,\n    )\n\n    # Apply constraint forces to the corresponding links\n    if W_f_L_constraints.shape[0] > 0:\n        # Get the constraint map from the model's kinematic parameters\n        constraint_map = model.kin_dyn_parameters.constraints\n\n        if constraint_map is not None:\n            # Stack the parent link indices for both sides of each constraint\n            parent_indices_flat = jnp.concatenate(\n                [constraint_map.parent_link_idxs_1, constraint_map.parent_link_idxs_2],\n            )\n\n            # Flatten the constraint wrenches to match the flattened parent indices\n            constraint_wrenches_flat = W_f_L_constraints.reshape(-1, 6)\n\n            # Apply constraint wrenches using scatter_add for better performance\n            W_f_L_total = W_f_L_total.at[parent_indices_flat].add(\n                constraint_wrenches_flat\n            )\n\n    # Store the link forces in a references object.\n    references = js.references.JaxSimModelReferences.build(\n        model=model,\n        data=data,\n        velocity_representation=data.velocity_representation,\n        
link_forces=W_f_L_total,\n    )\n\n    # Compute forward dynamics.\n    #\n    # - Joint accelerations: s̈ ∈ ℝⁿ\n    # - Base acceleration: v̇_WB ∈ ℝ⁶\n    #\n    # Note that ABA returns the base acceleration in the velocity representation\n    # stored in the `data` object.\n    v̇_WB, s̈ = js.model.forward_dynamics_aba(\n        model=model,\n        data=data,\n        joint_forces=joint_torques,\n        link_forces=references.link_forces(model=model, data=data),\n    )\n\n    return v̇_WB, s̈, contact_state\n\n\n@jax.jit\n@js.common.named_scope\ndef system_position_dynamics(\n    data: js.data.JaxSimModelData,\n    baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,\n) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:\n    r\"\"\"\n    Compute the dynamics of the system position.\n\n    Args:\n        data: The data of the considered model.\n        baumgarte_quaternion_regularization:\n            The Baumgarte regularization coefficient for adjusting the quaternion norm.\n\n    Returns:\n        A tuple containing the derivative of the base position, the derivative of the\n        base quaternion, and the derivative of the joint positions.\n\n    Note:\n        In inertial-fixed representation, the linear component of the base velocity is not\n        the derivative of the base position. 
In fact, the base velocity is defined as:\n        :math:`{} ^W v_{W, B} = \\begin{bmatrix} {} ^W \\dot{p}_B S({} ^W \\omega_{W, B}) {} ^W p _B\\\\ {} ^W \\omega_{W, B} \\end{bmatrix}`.\n        Where :math:`S(\\cdot)` is the skew-symmetric matrix operator.\n    \"\"\"\n\n    ṡ = data.joint_velocities\n    W_Q_B = data.base_orientation\n    W_ω_WB = data.base_velocity[3:6]\n    W_ṗ_B = data.base_velocity[0:3] + Skew.wedge(W_ω_WB) @ data.base_position\n\n    W_Q̇_B = Quaternion.derivative(\n        quaternion=W_Q_B,\n        omega=W_ω_WB,\n        omega_in_body_fixed=False,\n        K=baumgarte_quaternion_regularization,\n    ).squeeze()\n\n    return W_ṗ_B, W_Q̇_B, ṡ\n\n\n@jax.jit\n@js.common.named_scope\ndef system_dynamics(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    link_forces: jtp.Vector | None = None,\n    joint_torques: jtp.Vector | None = None,\n    baumgarte_quaternion_regularization: jtp.FloatLike = 1.0,\n) -> dict[str, jtp.Vector]:\n    \"\"\"\n    Compute the dynamics of the system.\n\n    Args:\n        model: The model to consider.\n        data: The data of the considered model.\n        link_forces:\n            The 6D forces to apply to the links expressed in the frame corresponding to\n            the velocity representation of `data`.\n        joint_torques: The joint torques acting on the joints.\n        baumgarte_quaternion_regularization:\n            The Baumgarte regularization coefficient used to adjust the norm of the\n            quaternion (only used in integrators not operating on the SO(3) manifold).\n\n    Returns:\n        A dictionary containing the derivatives of the base position, the base quaternion,\n        the joint positions, the base linear velocity, the base angular velocity, and the\n        joint velocities.\n    \"\"\"\n\n    with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):\n        W_v̇_WB, s̈, contact_state_derivative = 
system_acceleration(\n            model=model,\n            data=data,\n            joint_torques=joint_torques,\n            link_forces=link_forces,\n        )\n\n        W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(\n            data=data,\n            baumgarte_quaternion_regularization=baumgarte_quaternion_regularization,\n        )\n\n    return dict(\n        base_position=W_ṗ_B,\n        base_quaternion=W_Q̇_B,\n        joint_positions=ṡ,\n        base_linear_velocity=W_v̇_WB[0:3],\n        base_angular_velocity=W_v̇_WB[3:6],\n        joint_velocities=s̈,\n        contact_state=contact_state_derivative,\n    )\n"
  },
  {
    "path": "src/jaxsim/api/references.py",
    "content": "from __future__ import annotations\n\nimport functools\n\nimport jax\nimport jax.numpy as jnp\nimport jax_dataclasses\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim import exceptions\nfrom jaxsim.utils.tracing import not_tracing\n\nfrom .common import VelRepr\n\ntry:\n    from typing import Self\nexcept ImportError:\n    from typing_extensions import Self\n\n\n@jax_dataclasses.pytree_dataclass\nclass JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):\n    \"\"\"\n    Class containing the references for a `JaxSimModel` object.\n\n    Attributes:\n        _link_forces: The link 6D forces in inertial-fixed representation.\n        _joint_force_references: The joint force references.\n    \"\"\"\n\n    _link_forces: jtp.Matrix\n    _joint_force_references: jtp.Vector\n\n    @staticmethod\n    def zero(\n        model: js.model.JaxSimModel,\n        data: js.data.JaxSimModelData | None = None,\n        velocity_representation: VelRepr = VelRepr.Inertial,\n    ) -> JaxSimModelReferences:\n        \"\"\"\n        Create a `JaxSimModelReferences` object with zero references.\n\n        Args:\n            model: The model for which to create the zero references.\n            data:\n                The data of the model, only needed if the velocity representation is\n                not inertial-fixed.\n            velocity_representation: The velocity representation to use.\n\n        Returns:\n            A `JaxSimModelReferences` object with zero state.\n        \"\"\"\n\n        return JaxSimModelReferences.build(\n            model=model, data=data, velocity_representation=velocity_representation\n        )\n\n    @staticmethod\n    def build(\n        model: js.model.JaxSimModel,\n        joint_force_references: jtp.VectorLike | None = None,\n        link_forces: jtp.MatrixLike | None = None,\n        data: js.data.JaxSimModelData | None = None,\n        velocity_representation: VelRepr | None = None,\n    ) -> 
JaxSimModelReferences:\n        \"\"\"\n        Create a `JaxSimModelReferences` object with the given references.\n\n        Args:\n            model: The model for which to create the state.\n            joint_force_references: The joint force references.\n            link_forces: The link 6D forces in the desired representation.\n            data:\n                The data of the model, only needed if the velocity representation is\n                not inertial-fixed.\n            velocity_representation: The velocity representation to use.\n\n        Returns:\n            A `JaxSimModelReferences` object with the given references.\n        \"\"\"\n\n        # Create or adjust joint force references.\n        joint_force_references = jnp.atleast_1d(\n            jnp.array(joint_force_references, dtype=float).squeeze()\n            if joint_force_references is not None\n            else jnp.zeros(model.dofs())\n        ).astype(float)\n\n        # Create or adjust link forces.\n        f_L = jnp.atleast_2d(\n            jnp.array(link_forces, dtype=float).squeeze()\n            if link_forces is not None\n            else jnp.zeros((model.number_of_links(), 6))\n        ).astype(float)\n\n        # Select the velocity representation.\n        velocity_representation = (\n            velocity_representation\n            if velocity_representation is not None\n            else getattr(data, \"velocity_representation\", VelRepr.Inertial)\n        )\n\n        # Create a zero references object.\n        references = JaxSimModelReferences(\n            _link_forces=f_L,\n            _joint_force_references=joint_force_references,\n            velocity_representation=velocity_representation,\n        )\n\n        # If the velocity representation is inertial-fixed, we can return\n        # the references directly, as we store the link forces in this frame.\n        if velocity_representation is VelRepr.Inertial:\n            return references\n\n        # Store the 
joint force references.\n        references = references.set_joint_force_references(\n            forces=joint_force_references,\n            model=model,\n            joint_names=model.joint_names(),\n        )\n\n        # Apply the link forces.\n        references = references.apply_link_forces(\n            forces=f_L,\n            model=model,\n            data=data,\n            link_names=model.link_names(),\n            additive=False,\n        )\n\n        return references\n\n    def valid(self, model: js.model.JaxSimModel | None = None) -> bool:\n        \"\"\"\n        Check if the current references are valid for the given model.\n\n        Args:\n            model: The model to check against.\n\n        Returns:\n            `True` if the current references are valid for the given model,\n            `False` otherwise.\n        \"\"\"\n\n        if model is None:\n            return True\n\n        shape = self._joint_force_references.shape\n        expected_shape = (model.dofs(),)\n\n        if shape != expected_shape:\n            return False\n\n        shape = self._link_forces.shape\n        expected_shape = (model.number_of_links(), 6)\n\n        if shape != expected_shape:\n            return False\n\n        return True\n\n    # ==================\n    # Extract quantities\n    # ==================\n    @js.common.named_scope\n    @functools.partial(jax.jit, static_argnames=[\"link_names\"])\n    def link_forces(\n        self,\n        model: js.model.JaxSimModel | None = None,\n        data: js.data.JaxSimModelData | None = None,\n        link_names: tuple[str, ...] 
| None = None,\n    ) -> jtp.Matrix:\n        \"\"\"\n        Return the link forces expressed in the frame of the active representation.\n\n        Args:\n            model: The model to consider.\n            data: The data of the considered model.\n            link_names: The names of the links corresponding to the forces.\n\n        Returns:\n            If no model and no link names are provided, the link forces as a\n            `(n_links,6)` matrix corresponding to the default link serialization\n            of the original model used to build the actuation object.\n            If a model is provided and no link names are provided, the link forces\n            as a `(n_links,6)` matrix corresponding to the serialization of the\n            provided model.\n            If both a model and link names are provided, the link forces as a\n            `(len(link_names),6)` matrix corresponding to the serialization of\n            the passed link names vector.\n\n        Note:\n            The returned link forces are those passed as user inputs when integrating\n            the dynamics of the model. They are summed with other forces related\n            e.g. 
to the contact model and other kinematic constraints.\n        \"\"\"\n\n        W_f_L = self._link_forces\n\n        # Return all link forces in inertial-fixed representation using the implicit\n        # serialization.\n        if model is None:\n            if self.velocity_representation is not VelRepr.Inertial:\n                msg = \"Missing model to use a representation different from {}\"\n                raise ValueError(msg.format(VelRepr.Inertial.name))\n\n            if link_names is not None:\n                raise ValueError(\"Link names cannot be provided without a model\")\n\n            return W_f_L\n\n        # If we have the model, we can extract the link names, if not provided.\n        link_idxs = (\n            js.link.names_to_idxs(link_names=link_names, model=model)\n            if link_names is not None\n            else jnp.arange(model.number_of_links())\n        )\n\n        # In inertial-fixed representation, we already have the link forces.\n        if self.velocity_representation is VelRepr.Inertial:\n            return W_f_L[link_idxs, :]\n\n        if data is None:\n            msg = \"Missing model data to use a representation different from {}\"\n            raise ValueError(msg.format(VelRepr.Inertial.name))\n\n        if not_tracing(self._link_forces) and not data.valid(model=model):\n            raise ValueError(\"The provided data is not valid for the model\")\n\n        # Helper function to convert a single 6D force to the active representation\n        # considering as body the link (i.e. 
L_f_L and LW_f_L).\n        def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix:\n\n            return jax.vmap(\n                lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation(\n                    array=W_f_L,\n                    other_representation=self.velocity_representation,\n                    transform=W_H_L,\n                    is_force=True,\n                )\n            )(W_f_L, W_H_L)\n\n        # The f_L output is either L_f_L or LW_f_L, depending on the representation.\n        W_H_L = data._link_transforms\n        f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :])\n\n        return f_L\n\n    def joint_force_references(\n        self,\n        model: js.model.JaxSimModel | None = None,\n        joint_names: tuple[str, ...] | None = None,\n    ) -> jtp.Vector:\n        \"\"\"\n        Return the joint force references.\n\n        Args:\n            model: The model to consider.\n            joint_names: The names of the joints corresponding to the forces.\n\n        Returns:\n            If no model and no joint names are provided, the joint forces as a\n            `(DoFs,)` vector corresponding to the default joint serialization\n            of the original model used to build the actuation object.\n            If a model is provided and no joint names are provided, the joint forces\n            as a `(DoFs,)` vector corresponding to the serialization of the\n            provided model.\n            If both a model and joint names are provided, the joint forces as a\n            `(len(joint_names),)` vector corresponding to the serialization of\n            the passed joint names vector.\n\n        Note:\n            The returned joint forces are those passed as user inputs when integrating\n            the dynamics of the model. They are summed with other joint forces related\n            e.g. to the enforcement of other kinematic constraints. 
Keep also in mind\n            that the presence of joint friction and other similar effects can make the\n            actual joint forces different from the references.\n        \"\"\"\n\n        if model is None:\n            if joint_names is not None:\n                raise ValueError(\"Joint names cannot be provided without a model\")\n\n            return self._joint_force_references\n\n        if not_tracing(self._joint_force_references) and not self.valid(model=model):\n            msg = \"The actuation object is not compatible with the provided model\"\n            raise ValueError(msg)\n\n        joint_idxs = (\n            js.joint.names_to_idxs(joint_names=joint_names, model=model)\n            if joint_names is not None\n            else jnp.arange(model.number_of_joints())\n        )\n\n        return jnp.atleast_1d(\n            self._joint_force_references[joint_idxs].squeeze()\n        ).astype(float)\n\n    # ================\n    # Store quantities\n    # ================\n    @js.common.named_scope\n    @functools.partial(jax.jit, static_argnames=[\"joint_names\"])\n    def set_joint_force_references(\n        self,\n        forces: jtp.VectorLike,\n        model: js.model.JaxSimModel | None = None,\n        joint_names: tuple[str, ...] 
| None = None,\n    ) -> Self:\n        \"\"\"\n        Set the joint force references.\n\n        Args:\n            forces: The joint force references.\n            model:\n                The model to consider, only needed if a joint serialization different\n                from the implicit one is used.\n            joint_names: The names of the joints corresponding to the forces.\n\n        Returns:\n            A new `JaxSimModelReferences` object with the given joint force references.\n        \"\"\"\n\n        forces = jnp.atleast_1d(jnp.array(forces, dtype=float).squeeze())\n\n        def replace(forces: jtp.Vector) -> JaxSimModelReferences:\n            return self.replace(\n                validate=True,\n                _joint_force_references=jnp.atleast_1d(forces.squeeze()).astype(float),\n            )\n\n        if model is None:\n            return replace(forces=forces)\n\n        if not_tracing(forces) and not self.valid(model=model):\n            msg = \"The references object is not compatible with the provided model\"\n            raise ValueError(msg)\n\n        joint_idxs = (\n            js.joint.names_to_idxs(joint_names=joint_names, model=model)\n            if joint_names is not None\n            else jnp.arange(model.number_of_joints())\n        )\n\n        return replace(forces=self._joint_force_references.at[joint_idxs].set(forces))\n\n    @js.common.named_scope\n    @functools.partial(jax.jit, static_argnames=[\"link_names\", \"additive\"])\n    def apply_link_forces(\n        self,\n        forces: jtp.MatrixLike,\n        model: js.model.JaxSimModel | None = None,\n        data: js.data.JaxSimModelData | None = None,\n        link_names: tuple[str, ...] 
| str | None = None,\n        additive: bool = False,\n    ) -> Self:\n        \"\"\"\n        Apply the link forces.\n\n        Args:\n            forces: The link 6D forces in the active representation.\n            model:\n                The model to consider, only needed if a link serialization different\n                from the implicit one is used.\n            data:\n                The data of the considered model, only needed if the velocity\n                representation is not inertial-fixed.\n            link_names: The names of the links corresponding to the forces.\n            additive:\n                Whether to add the forces to the existing ones instead of replacing them.\n\n        Returns:\n            A new `JaxSimModelReferences` object with the given link forces.\n\n        Note:\n            The link forces must be expressed in the active representation.\n            Then, we always convert and store forces in inertial-fixed representation.\n        \"\"\"\n\n        f_L = jnp.atleast_2d(forces).astype(float)\n\n        # Helper function to replace the link forces.\n        def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:\n            return self.replace(\n                validate=True,\n                _link_forces=jnp.atleast_2d(forces.squeeze()).astype(float),\n            )\n\n        # In this case, we allow only to set the inertial 6D forces to all links\n        # using the implicit link serialization.\n        if model is None:\n            if self.velocity_representation is not VelRepr.Inertial:\n                msg = \"Missing model to use a representation different from {}\"\n                raise ValueError(msg.format(VelRepr.Inertial.name))\n\n            if link_names is not None:\n                raise ValueError(\"Link names cannot be provided without a model\")\n\n            W_f_L = f_L\n\n            W_f0_L = jnp.zeros_like(W_f_L) if not additive else self._link_forces\n\n            return 
replace(forces=W_f0_L + W_f_L)\n\n        if link_names is not None and len(link_names) != f_L.shape[0]:\n            msg = \"The number of link names ({}) must match the number of forces ({})\"\n            raise ValueError(msg.format(len(link_names), f_L.shape[0]))\n\n        # Extract the link indices.\n        link_idxs = (\n            js.link.names_to_idxs(link_names=link_names, model=model)\n            if link_names is not None\n            else jnp.arange(model.number_of_links())\n        )\n\n        # Compute the bias depending on whether we either set or add the link forces.\n        W_f0_L = (\n            jnp.zeros_like(f_L) if not additive else self._link_forces[link_idxs, :]\n        )\n\n        # If inertial-fixed representation, we can directly store the link forces.\n        if self.velocity_representation is VelRepr.Inertial:\n            W_f_L = f_L\n            return replace(\n                forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L)\n            )\n\n        if data is None:\n            msg = \"Missing model data to use a representation different from {}\"\n            raise ValueError(msg.format(VelRepr.Inertial.name))\n\n        if not_tracing(forces) and not data.valid(model=model):\n            raise ValueError(\"The provided data is not valid for the model\")\n\n        W_H_L = data._link_transforms\n\n        # Convert a single 6D force to the inertial representation\n        # considering as body the link (i.e. 
L_f_L and LW_f_L).\n        # The f_L input is either L_f_L or LW_f_L, depending on the representation.\n        W_f_L = JaxSimModelReferences.other_representation_to_inertial(\n            array=f_L,\n            other_representation=self.velocity_representation,\n            transform=W_H_L[link_idxs] if model.number_of_links() > 1 else W_H_L,\n            is_force=True,\n        )\n\n        return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L))\n\n    def apply_frame_forces(\n        self,\n        forces: jtp.MatrixLike,\n        model: js.model.JaxSimModel,\n        data: js.data.JaxSimModelData,\n        frame_names: tuple[str, ...] | str | None = None,\n        additive: bool = False,\n    ) -> Self:\n        \"\"\"\n        Apply the frame forces.\n\n        Args:\n            forces: The frame 6D forces in the active representation.\n            model:\n                The model to consider, only needed if a frame serialization different\n                from the implicit one is used.\n            data:\n                The data of the considered model, only needed if the velocity\n                representation is not inertial-fixed.\n            frame_names: The names of the frames corresponding to the forces.\n            additive:\n                Whether to add the forces to the existing ones instead of replacing them.\n\n        Returns:\n            A new `JaxSimModelReferences` object with the given frame forces.\n\n        Note:\n            The frame forces must be expressed in the active representation.\n            Then, we always convert and store forces in inertial-fixed representation.\n        \"\"\"\n\n        f_F = jnp.atleast_2d(forces).astype(float)\n\n        if len(frame_names) != f_F.shape[0]:\n            msg = \"The number of frame names ({}) must match the number of forces ({})\"\n            raise ValueError(msg.format(len(frame_names), f_F.shape[0]))\n\n        # Extract the frame indices.\n        
frame_idxs = (\n            js.frame.names_to_idxs(frame_names=frame_names, model=model)\n            if frame_names is not None\n            else jnp.arange(len(model.frame_names()))\n        )\n\n        parent_link_idxs = jnp.array(model.kin_dyn_parameters.frame_parameters.body)[\n            frame_idxs - model.number_of_links()\n        ]\n\n        exceptions.raise_value_error_if(\n            condition=~data.valid(model=model),\n            msg=\"The provided data is not valid for the model\",\n        )\n        W_H_Fi = jax.vmap(\n            lambda frame_idx: js.frame.transform(\n                model=model, data=data, frame_index=frame_idx\n            )\n        )(frame_idxs)\n\n        # Helper function to convert a single 6D force to the inertial representation\n        # considering as body the frame (i.e. L_f_F and LW_f_F).\n        def to_inertial(f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike) -> jtp.Matrix:\n            return JaxSimModelReferences.other_representation_to_inertial(\n                array=f_F,\n                other_representation=self.velocity_representation,\n                transform=W_H_F,\n                is_force=True,\n            )\n\n        match self.velocity_representation:\n            case VelRepr.Inertial:\n                W_f_F = f_F\n\n            case VelRepr.Body | VelRepr.Mixed:\n                W_f_F = jax.vmap(to_inertial)(f_F, W_H_Fi)\n\n            case _:\n                raise ValueError(\"Invalid velocity representation.\")\n\n        # Sum the forces on the parent links.\n        mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links())\n        W_f_L = mask.T @ W_f_F\n\n        with self.switch_velocity_representation(\n            velocity_representation=VelRepr.Inertial\n        ):\n            references = self.apply_link_forces(\n                model=model,\n                data=data,\n                forces=W_f_L,\n                additive=additive,\n            )\n\n        with 
references.switch_velocity_representation(\n            velocity_representation=self.velocity_representation\n        ):\n            return references\n"
  },
  {
    "path": "src/jaxsim/exceptions.py",
    "content": "import os\n\nimport jax\n\n\ndef raise_if(\n    condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs\n) -> None:\n    \"\"\"\n    Raise a host-side exception if a condition is met. Useful in jit-compiled functions.\n\n    Args:\n        condition:\n            The boolean condition of the evaluated expression that triggers\n            the exception during runtime.\n        exception: The type of exception to raise.\n        msg:\n            The message to display when the exception is raised. The message can be a\n            format string (fmt), whose fields are filled with the args and kwargs.\n        *args: The arguments to fill the format string.\n        **kwargs: The keyword arguments to fill the format string\n    \"\"\"\n\n    # Disable host callback if running on unsupported hardware or if the user\n    # explicitly disabled it.\n    if jax.devices()[0].platform in {\"tpu\", \"METAL\"} or not os.environ.get(\n        \"JAXSIM_ENABLE_EXCEPTIONS\", 0\n    ):\n        return\n\n    # Check early that the format string is well-formed.\n    try:\n        _ = msg.format(*args, **kwargs)\n    except Exception as e:\n        msg = \"Error in formatting exception message with args={} and kwargs={}\"\n        raise ValueError(msg.format(args, kwargs)) from e\n\n    def _raise_exception(condition: bool, *args, **kwargs) -> None:\n        \"\"\"The function called by the JAX callback.\"\"\"\n\n        if condition:\n            raise exception(msg.format(*args, **kwargs))\n\n    def _callback(args, kwargs) -> None:\n        \"\"\"The function that calls the JAX callback, executed only when needed.\"\"\"\n\n        jax.debug.callback(_raise_exception, condition, *args, **kwargs)\n\n    # Since running a callable on the host is expensive, we prevent its execution\n    # if the condition is False with a low-level conditional expression.\n    def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None:\n        return 
jax.lax.cond(\n            condition,\n            _callback,\n            lambda args, kwargs: None,\n            args,\n            kwargs,\n        )\n\n    return _run_callback_only_if_condition_is_true(*args, **kwargs)\n\n\ndef raise_runtime_error_if(\n    condition: bool | jax.Array, msg: str, *args, **kwargs\n) -> None:\n    \"\"\"\n    Raise a RuntimeError if a condition is met. Useful in jit-compiled functions.\n    \"\"\"\n\n    return raise_if(condition, RuntimeError, msg, *args, **kwargs)\n\n\ndef raise_value_error_if(\n    condition: bool | jax.Array, msg: str, *args, **kwargs\n) -> None:\n    \"\"\"\n    Raise a ValueError if a condition is met. Useful in jit-compiled functions.\n    \"\"\"\n\n    return raise_if(condition, ValueError, msg, *args, **kwargs)\n"
  },
  {
    "path": "src/jaxsim/logging.py",
    "content": "import enum\nimport inspect\nimport logging\nimport os\nimport warnings\n\nimport coloredlogs\n\n\nclass JaxSimWarning(UserWarning):\n    pass\n\n\n_original_showwarning = warnings.showwarning\n_jaxsim_root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), \"..\"))\n\n\ndef pretty_jaxsim_warning(message, category, filename, lineno, file=None, line=None):\n    try:\n        caller_frame = inspect.stack()[2]\n        caller_file = caller_frame.filename\n    except Exception:\n        caller_file = filename\n\n    if caller_file.startswith(_jaxsim_root_dir):\n        print(f\"\\033[93m⚠️  {category.__name__}:\\033[0m {message}\")\n        print(f\"   → {filename}:{lineno}\")\n    else:\n        _original_showwarning(message, category, filename, lineno, file, line)\n\n\n# Register filter & formatter only for JaxSimWarning\n# and configure it to show each warning only once\nwarnings.showwarning = pretty_jaxsim_warning\nwarnings.filterwarnings(\"once\")\n\n\n# Utility function to issue a JaxSim warning\ndef jaxsim_warn(msg):\n    warnings.warn(msg, category=JaxSimWarning, stacklevel=2)\n\n\nLOGGER_NAME = \"jaxsim\"\n\n\nclass LoggingLevel(enum.IntEnum):\n    NOTSET = logging.NOTSET\n    DEBUG = logging.DEBUG\n    INFO = logging.INFO\n    WARNING = logging.WARNING\n    ERROR = logging.ERROR\n    CRITICAL = logging.CRITICAL\n\n\ndef _logger() -> logging.Logger:\n    return logging.getLogger(name=LOGGER_NAME)\n\n\ndef set_logging_level(level: int | LoggingLevel = LoggingLevel.WARNING):\n    if isinstance(level, int):\n        level = LoggingLevel(level)\n\n    _logger().setLevel(level=level.value)\n\n\ndef get_logging_level() -> LoggingLevel:\n    level = _logger().getEffectiveLevel()\n    return LoggingLevel(level)\n\n\ndef configure(level: LoggingLevel = LoggingLevel.WARNING) -> None:\n    info(\"Configuring the 'jaxsim' logger\")\n\n    handler = logging.StreamHandler()\n    fmt = \"%(name)s[%(process)d] %(levelname)s %(message)s\"\n    
handler.setFormatter(fmt=coloredlogs.ColoredFormatter(fmt=fmt))\n    _logger().addHandler(hdlr=handler)\n\n    # Do not propagate the messages to handlers of parent loggers\n    # (preventing duplicate logging)\n    _logger().propagate = False\n\n    set_logging_level(level=level)\n\n\ndef debug(msg: str = \"\") -> None:\n    _logger().debug(msg=msg)\n\n\ndef info(msg: str = \"\") -> None:\n    _logger().info(msg=msg)\n\n\ndef warning(msg: str = \"\") -> None:\n    _logger().warning(msg=msg)\n\n\ndef error(msg: str = \"\") -> None:\n    _logger().error(msg=msg)\n\n\ndef critical(msg: str = \"\") -> None:\n    _logger().critical(msg=msg)\n\n\ndef exception(msg: str = \"\") -> None:\n    _logger().exception(msg=msg)\n"
  },
  {
    "path": "src/jaxsim/math/__init__.py",
    "content": "from .adjoint import Adjoint\nfrom .cross import Cross\nfrom .inertia import Inertia\nfrom .quaternion import Quaternion\nfrom .rotation import Rotation\nfrom .skew import Skew\nfrom .transform import Transform\nfrom .utils import safe_norm\n\nfrom .joint_model import JointModel, supported_joint_motion  # isort:skip\n\n# Define the default standard gravity constant.\nSTANDARD_GRAVITY = 9.81\n"
  },
  {
    "path": "src/jaxsim/math/adjoint.py",
    "content": "import jax.numpy as jnp\nimport jaxlie\n\nimport jaxsim.typing as jtp\n\nfrom .skew import Skew\n\n\nclass Adjoint:\n    \"\"\"\n    A utility class for adjoint matrix operations.\n    \"\"\"\n\n    @staticmethod\n    def from_quaternion_and_translation(\n        quaternion: jtp.Vector | None = None,\n        translation: jtp.Vector | None = None,\n        inverse: bool = False,\n        normalize_quaternion: bool = False,\n    ) -> jtp.Matrix:\n        \"\"\"\n        Create an adjoint matrix from a quaternion and a translation.\n\n        Args:\n            quaternion: A quaternion vector (4D) representing orientation.\n            translation: A translation vector (3D).\n            inverse: Whether to compute the inverse adjoint.\n            normalize_quaternion: Whether to normalize the quaternion before creating the adjoint.\n\n        Returns:\n            The adjoint matrix.\n        \"\"\"\n        quaternion = quaternion if quaternion is not None else jnp.array([1.0, 0, 0, 0])\n        translation = translation if translation is not None else jnp.zeros(3)\n        assert quaternion.size == 4\n        assert translation.size == 3\n\n        Q_sixd = jaxlie.SO3(wxyz=quaternion)\n        Q_sixd = Q_sixd if not normalize_quaternion else Q_sixd.normalize()\n\n        return Adjoint.from_rotation_and_translation(\n            rotation=Q_sixd.as_matrix(), translation=translation, inverse=inverse\n        )\n\n    @staticmethod\n    def from_transform(transform: jtp.MatrixLike, inverse: bool = False) -> jtp.Matrix:\n        \"\"\"\n        Create an adjoint matrix from a transformation matrix.\n\n        Args:\n            transform: A 4x4 transformation matrix.\n            inverse: Whether to compute the inverse adjoint.\n\n        Returns:\n            The 6x6 adjoint matrix.\n        \"\"\"\n\n        A_H_B = transform\n\n        return (\n            jaxlie.SE3.from_matrix(matrix=A_H_B).adjoint()\n            if not inverse\n            else 
jaxlie.SE3.from_matrix(matrix=A_H_B).inverse().adjoint()\n        )\n\n    @staticmethod\n    def from_rotation_and_translation(\n        rotation: jtp.Matrix | None = None,\n        translation: jtp.Vector | None = None,\n        inverse: bool = False,\n    ) -> jtp.Matrix:\n        \"\"\"\n        Create an adjoint matrix from a rotation matrix and a translation vector.\n\n        Args:\n            rotation: A 3x3 rotation matrix.\n            translation: A translation vector (3D).\n            inverse: Whether to compute the inverse adjoint. Default is False.\n\n        Returns:\n            The adjoint matrix.\n        \"\"\"\n        rotation = rotation if rotation is not None else jnp.eye(3)\n        translation = translation if translation is not None else jnp.zeros(3)\n\n        assert rotation.shape == (3, 3)\n        assert translation.size == 3\n\n        A_R_B = rotation.squeeze()\n        A_o_B = translation.squeeze()\n\n        if not inverse:\n            X = A_X_B = jnp.vstack(  # noqa: F841\n                [\n                    jnp.block([A_R_B, Skew.wedge(A_o_B) @ A_R_B]),\n                    jnp.block([jnp.zeros(shape=(3, 3)), A_R_B]),\n                ]\n            )\n        else:\n            X = B_X_A = jnp.vstack(  # noqa: F841\n                [\n                    jnp.block([A_R_B.T, -A_R_B.T @ Skew.wedge(A_o_B)]),\n                    jnp.block([jnp.zeros(shape=(3, 3)), A_R_B.T]),\n                ]\n            )\n\n        return X\n\n    @staticmethod\n    def to_transform(adjoint: jtp.Matrix) -> jtp.Matrix:\n        \"\"\"\n        Convert an adjoint matrix to a transformation matrix.\n\n        Args:\n            adjoint: The adjoint matrix (6x6).\n\n        Returns:\n            The transformation matrix (4x4).\n        \"\"\"\n        X = adjoint.squeeze()\n        assert X.shape == (6, 6)\n\n        R = X[0:3, 0:3]\n        o_x_R = X[0:3, 3:6]\n\n        H = jnp.vstack(\n            [\n                jnp.block([R, 
Skew.vee(matrix=o_x_R @ R.T)]),\n                jnp.array([0, 0, 0, 1]),\n            ]\n        )\n\n        return H\n\n    @staticmethod\n    def inverse(adjoint: jtp.Matrix) -> jtp.Matrix:\n        \"\"\"\n        Compute the inverse of an adjoint matrix.\n\n        Args:\n            adjoint: The adjoint matrix.\n\n        Returns:\n            The inverse adjoint matrix.\n        \"\"\"\n        A_X_B = adjoint.reshape(-1, 6, 6)\n\n        A_R_B_T = jnp.swapaxes(A_X_B[..., 0:3, 0:3], -2, -1)\n        A_T_B = A_X_B[..., 0:3, 3:6]\n\n        return jnp.concatenate(\n            [\n                jnp.concatenate(\n                    [A_R_B_T, -A_R_B_T @ A_T_B @ A_R_B_T],\n                    axis=-1,\n                ),\n                jnp.concatenate([jnp.zeros_like(A_R_B_T), A_R_B_T], axis=-1),\n            ],\n            axis=-2,\n        ).reshape(adjoint.shape)\n"
  },
  {
    "path": "src/jaxsim/math/cross.py",
    "content": "import jax.numpy as jnp\n\nimport jaxsim.typing as jtp\n\nfrom .skew import Skew\n\n\nclass Cross:\n    \"\"\"\n    A utility class for cross product matrix operations.\n    \"\"\"\n\n    @staticmethod\n    def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix:\n        \"\"\"\n        Compute the cross product matrix for 6D velocities.\n\n        Args:\n            velocity_sixd: A 6D velocity vector [v, ω].\n\n        Returns:\n            The cross product matrix (6x6).\n\n        Raises:\n            ValueError: If the input vector does not have a size of 6.\n        \"\"\"\n        velocity_sixd = velocity_sixd.reshape(-1, 6)\n\n        v, ω = jnp.split(velocity_sixd, 2, axis=-1)\n\n        v_cross = jnp.concatenate(\n            [\n                jnp.concatenate(\n                    [Skew.wedge(ω), jnp.zeros((ω.shape[0], 3, 3)).squeeze()], axis=-2\n                ),\n                jnp.concatenate([Skew.wedge(v), Skew.wedge(ω)], axis=-2),\n            ],\n            axis=-1,\n        )\n\n        return v_cross\n\n    @staticmethod\n    def vx_star(velocity_sixd: jtp.Vector) -> jtp.Matrix:\n        \"\"\"\n        Compute the negative transpose of the cross product matrix for 6D velocities.\n\n        Args:\n            velocity_sixd: A 6D velocity vector [v, ω].\n\n        Returns:\n            The negative transpose of the cross product matrix (6x6).\n\n        Raises:\n            ValueError: If the input vector does not have a size of 6.\n        \"\"\"\n        v_cross_star = -Cross.vx(velocity_sixd).T\n        return v_cross_star\n"
  },
  {
    "path": "src/jaxsim/math/inertia.py",
    "content": "import jax.numpy as jnp\n\nimport jaxsim.typing as jtp\n\nfrom .skew import Skew\n\n\nclass Inertia:\n    \"\"\"\n    A utility class for inertia matrix operations.\n    \"\"\"\n\n    @staticmethod\n    def to_sixd(mass: jtp.Float, com: jtp.Vector, I: jtp.Matrix) -> jtp.Matrix:\n        \"\"\"\n        Convert mass, center of mass, and inertia matrix to a 6x6 inertia matrix.\n\n        Args:\n            mass: The mass of the body.\n            com: The center of mass position (3D).\n            I: The 3x3 inertia matrix.\n\n        Returns:\n            The 6x6 inertia matrix.\n\n        Raises:\n            ValueError: If the shape of the inertia matrix I is not (3, 3).\n        \"\"\"\n        if I.shape != (3, 3):\n            raise ValueError(I, I.shape)\n\n        c = Skew.wedge(vector=com)\n\n        M = jnp.vstack(\n            [\n                jnp.block([mass * jnp.eye(3), mass * c.T]),\n                jnp.block([mass * c, I + mass * c @ c.T]),\n            ]\n        )\n\n        return M\n\n    @staticmethod\n    def to_params(M: jtp.Matrix) -> tuple[jtp.Float, jtp.Vector, jtp.Matrix]:\n        \"\"\"\n        Convert a 6x6 inertia matrix to mass, center of mass, and inertia matrix.\n\n        Args:\n            M: The 6x6 inertia matrix.\n\n        Returns:\n            A tuple containing mass, center of mass (3D), and inertia matrix (3x3).\n\n        Raises:\n            ValueError: If the input matrix M has an unexpected shape.\n        \"\"\"\n        m = jnp.diag(M[0:3, 0:3]).sum() / 3\n\n        mC = M[3:6, 0:3]\n        c = Skew.vee(mC) / m\n        I = M[3:6, 3:6] - (mC @ mC.T / m)\n\n        return m, c, I\n"
  },
  {
    "path": "src/jaxsim/math/joint_model.py",
    "content": "from __future__ import annotations\n\nimport jax\nimport jax.numpy as jnp\nimport jax_dataclasses\nimport jaxlie\nfrom jax_dataclasses import Static\n\nimport jaxsim.typing as jtp\nfrom jaxsim.math import Rotation\nfrom jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription\nfrom jaxsim.parsers.kinematic_graph import KinematicGraphTransforms\nfrom jaxsim.utils.jaxsim_dataclass import JaxsimDataclass\n\n\n@jax_dataclasses.pytree_dataclass\nclass JointModel(JaxsimDataclass):\n    \"\"\"\n    Class describing the joint kinematics of a robot model.\n\n    Attributes:\n        λ_H_pre:\n            The homogeneous transformation between the parent link and\n            the predecessor frame of each joint.\n        suc_H_i:\n            The homogeneous transformation between the successor frame and\n            the child link of each joint.\n        joint_dofs: The number of DoFs of each joint.\n        joint_names: The names of each joint.\n        joint_types: The types of each joint.\n\n    Note:\n        Due to the presence of the static attributes, this class needs to be created\n        already in a vectorized form. 
In other words, it cannot be created using vmap.\n    \"\"\"\n\n    λ_H_pre: jtp.Array\n    suc_H_i: jtp.Array\n\n    joint_dofs: Static[tuple[int, ...]]\n    joint_names: Static[tuple[str, ...]]\n    joint_types: Static[tuple[int, ...]]\n    joint_axis: Static[tuple[JointGenericAxis, ...]]\n\n    @staticmethod\n    def build(description: ModelDescription) -> JointModel:\n        \"\"\"\n        Build the joint model of a model description.\n\n        Args:\n            description: The model description to consider.\n\n        Returns:\n            The joint model of the considered model description.\n        \"\"\"\n\n        # The link index is equal to its body index: [0, number_of_bodies - 1].\n        ordered_links = sorted(\n            list(description.links_dict.values()),\n            key=lambda l: l.index,\n        )\n\n        # Note: the joint index is equal to its child link index, therefore it\n        # starts from 1.\n        ordered_joints = sorted(\n            list(description.joints_dict.values()),\n            key=lambda j: j.index,\n        )\n\n        # Allocate the parent-to-predecessor and successor-to-child transforms.\n        λ_H_pre = jnp.zeros(shape=(1 + len(ordered_joints), 4, 4), dtype=float)\n        suc_H_i = jnp.zeros(shape=(1 + len(ordered_joints), 4, 4), dtype=float)\n\n        # Initialize an identical parent-to-predecessor transform for the joint\n        # between the world frame W and the base link B.\n        λ_H_pre = λ_H_pre.at[0].set(jnp.eye(4))\n\n        # Initialize the successor-to-child transform of the joint between the\n        # world frame W and the base link B.\n        # We store here the optional transform between the root frame of the model\n        # and the base link frame (this is needed only if the pose of the link frame\n        # w.r.t. 
the implicit __model__ SDF frame is not the identity).\n        suc_H_i = suc_H_i.at[0].set(ordered_links[0].pose)\n\n        # Create the object to compute forward kinematics.\n        fk = KinematicGraphTransforms(graph=description)\n\n        # Compute the parent-to-predecessor and successor-to-child transforms for\n        # each joint belonging to the model.\n        # Note that the joint indices starts from i=1 given our joint model,\n        # therefore the entries at index 0 are not updated.\n        for joint in ordered_joints:\n            λ_H_pre = λ_H_pre.at[joint.index].set(\n                fk.relative_transform(relative_to=joint.parent.name, name=joint.name)\n            )\n            suc_H_i = suc_H_i.at[joint.index].set(\n                fk.relative_transform(relative_to=joint.name, name=joint.child.name)\n            )\n\n        # Define the DoFs of the base link.\n        base_dofs = 0 if description.fixed_base else 6\n\n        # We always add a dummy fixed joint between world and base.\n        # TODO: Port floating-base support also at this level, not only in RBDAs.\n        return JointModel(\n            λ_H_pre=λ_H_pre,\n            suc_H_i=suc_H_i,\n            # Static attributes\n            joint_dofs=tuple([base_dofs] + [1 for _ in ordered_joints]),\n            joint_names=tuple([\"world_to_base\"] + [j.name for j in ordered_joints]),\n            joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),\n            joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints),\n        )\n\n    def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix:\n        r\"\"\"\n        Return the homogeneous transformation between the parent link and\n        the predecessor frame of a joint.\n\n        Args:\n            joint_index: The index of the joint.\n\n        Returns:\n            The homogeneous transformation\n            :math:`{}^{\\lambda(i)} \\mathbf{H}_{\\text{pre}(i)}`.\n        
\"\"\"\n\n        return self.λ_H_pre[joint_index]\n\n    def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix:\n        r\"\"\"\n        Return the homogeneous transformation between the successor frame and\n        the child link of a joint.\n\n        Args:\n            joint_index: The index of the joint.\n\n        Returns:\n            The homogeneous transformation\n            :math:`{}^{\\text{suc}(i)} \\mathbf{H}_i`.\n        \"\"\"\n\n        return self.suc_H_i[joint_index]\n\n\n@jax.jit\ndef supported_joint_motion(\n    joint_types: jtp.Array, joint_positions: jtp.Matrix, joint_axes: jtp.Matrix\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the transforms of the joints.\n\n    Args:\n        joint_types: The types of the joints.\n        joint_positions: The positions of the joints.\n        joint_axes: The axes of the joints.\n\n    Returns:\n        The transforms of the joints.\n    \"\"\"\n\n    # Prepare the joint position\n    s = jnp.array(joint_positions).astype(float)\n\n    def compute_F() -> tuple[jtp.Matrix, jtp.Array]:\n        return jaxlie.SE3.identity()\n\n    def compute_R() -> tuple[jtp.Matrix, jtp.Array]:\n\n        # Get the additional argument specifying the joint axis.\n        # This is a metadata required by only some joint types.\n        axis = jnp.array(joint_axes).astype(float).squeeze()\n\n        pre_H_suc = jaxlie.SE3.from_matrix(\n            matrix=jnp.eye(4).at[:3, :3].set(Rotation.from_axis_angle(vector=s * axis))\n        )\n\n        return pre_H_suc\n\n    def compute_P() -> tuple[jtp.Matrix, jtp.Array]:\n\n        # Get the additional argument specifying the joint axis.\n        # This is a metadata required by only some joint types.\n        axis = jnp.array(joint_axes).astype(float).squeeze()\n\n        pre_H_suc = jaxlie.SE3.from_rotation_and_translation(\n            rotation=jaxlie.SO3.identity(),\n            translation=jnp.array(s * axis),\n        )\n\n        return pre_H_suc\n\n    return 
jax.lax.switch(\n        index=joint_types,\n        branches=(\n            compute_F,  # JointType.Fixed\n            compute_R,  # JointType.Revolute\n            compute_P,  # JointType.Prismatic\n        ),\n    ).as_matrix()\n"
  },
  {
    "path": "src/jaxsim/math/quaternion.py",
    "content": "import jax.lax\nimport jax.numpy as jnp\nimport jaxlie\n\nimport jaxsim.typing as jtp\n\nfrom .utils import safe_norm\n\n\nclass Quaternion:\n    \"\"\"\n    A utility class for quaternion operations.\n    \"\"\"\n\n    @staticmethod\n    def to_xyzw(wxyz: jtp.Vector) -> jtp.Vector:\n        \"\"\"\n        Convert a quaternion from WXYZ to XYZW representation.\n\n        Args:\n            wxyz: Quaternion in WXYZ representation.\n\n        Returns:\n            Quaternion in XYZW representation.\n        \"\"\"\n        return wxyz.squeeze()[jnp.array([1, 2, 3, 0])]\n\n    @staticmethod\n    def to_wxyz(xyzw: jtp.Vector) -> jtp.Vector:\n        \"\"\"\n        Convert a quaternion from XYZW to WXYZ representation.\n\n        Args:\n            xyzw: Quaternion in XYZW representation.\n\n        Returns:\n            Quaternion in WXYZ representation.\n        \"\"\"\n        return xyzw.squeeze()[jnp.array([3, 0, 1, 2])]\n\n    @staticmethod\n    def to_dcm(quaternion: jtp.Vector) -> jtp.Matrix:\n        \"\"\"\n        Convert a quaternion to a direction cosine matrix (DCM).\n\n        Args:\n            quaternion: Quaternion in XYZW representation.\n\n        Returns:\n            The Direction cosine matrix (DCM).\n        \"\"\"\n        return jaxlie.SO3(wxyz=quaternion).as_matrix()\n\n    @staticmethod\n    def from_dcm(dcm: jtp.Matrix) -> jtp.Vector:\n        \"\"\"\n        Convert a direction cosine matrix (DCM) to a quaternion.\n\n        Args:\n            dcm: Direction cosine matrix (DCM).\n\n        Returns:\n            Quaternion in WXYZ representation.\n        \"\"\"\n        return jaxlie.SO3.from_matrix(matrix=dcm).wxyz\n\n    @staticmethod\n    def derivative(\n        quaternion: jtp.Vector,\n        omega: jtp.Vector,\n        omega_in_body_fixed: bool = False,\n        K: float = 0.1,\n    ) -> jtp.Vector:\n        \"\"\"\n        Compute the derivative of a quaternion given angular velocity.\n\n        Args:\n            
quaternion: Quaternion in XYZW representation.\n            omega: Angular velocity vector.\n            omega_in_body_fixed (bool): Whether the angular velocity is in the body-fixed frame.\n            K (float): A scaling factor.\n\n        Returns:\n            The derivative of the quaternion.\n        \"\"\"\n        ω = omega.squeeze()\n        quaternion = quaternion.squeeze()\n\n        def Q_body(q: jtp.Vector) -> jtp.Matrix:\n            qw, qx, qy, qz = q\n\n            return jnp.array(\n                [\n                    [qw, -qx, -qy, -qz],\n                    [qx, qw, -qz, qy],\n                    [qy, qz, qw, -qx],\n                    [qz, -qy, qx, qw],\n                ]\n            )\n\n        def Q_inertial(q: jtp.Vector) -> jtp.Matrix:\n            qw, qx, qy, qz = q\n\n            return jnp.array(\n                [\n                    [qw, -qx, -qy, -qz],\n                    [qx, qw, qz, -qy],\n                    [qy, -qz, qw, qx],\n                    [qz, qy, -qx, qw],\n                ]\n            )\n\n        Q = jax.lax.cond(\n            pred=omega_in_body_fixed,\n            true_fun=Q_body,\n            false_fun=Q_inertial,\n            operand=quaternion,\n        )\n\n        norm_ω = safe_norm(ω)\n\n        qd = 0.5 * (\n            Q\n            @ jnp.hstack(\n                [\n                    K * norm_ω * (1 - safe_norm(quaternion)),\n                    ω,\n                ]\n            )\n        )\n\n        return jnp.vstack(qd)\n\n    @staticmethod\n    def integration(\n        quaternion: jtp.VectorLike,\n        dt: jtp.FloatLike,\n        omega: jtp.VectorLike,\n        omega_in_body_fixed: jtp.BoolLike = False,\n    ) -> jtp.Vector:\n        \"\"\"\n        Integrate a quaternion in SO(3) given an angular velocity.\n\n        Args:\n            quaternion: The quaternion to integrate.\n            dt: The time step.\n            omega: The angular velocity vector.\n            
omega_in_body_fixed:\n                Whether the angular velocity is in body-fixed representation\n                as opposed to the default inertial-fixed representation.\n\n        Returns:\n            The integrated quaternion.\n        \"\"\"\n\n        ω_AB = jnp.array(omega).squeeze().astype(float)\n        A_Q_B = jnp.array(quaternion).squeeze().astype(float)\n\n        # Build the initial SO(3) quaternion.\n        W_Q_B_t0 = jaxlie.SO3(wxyz=A_Q_B)\n\n        # Integrate the quaternion on the manifold.\n        W_Q_B_tf = jax.lax.select(\n            pred=omega_in_body_fixed,\n            on_true=(W_Q_B_t0 @ jaxlie.SO3.exp(tangent=dt * ω_AB)).wxyz,\n            on_false=(jaxlie.SO3.exp(tangent=dt * ω_AB) @ W_Q_B_t0).wxyz,\n        )\n\n        return W_Q_B_tf\n"
  },
  {
    "path": "src/jaxsim/math/rotation.py",
    "content": "import jax.numpy as jnp\nimport jaxlie\n\nimport jaxsim.typing as jtp\n\nfrom .skew import Skew\nfrom .utils import safe_norm\n\n\nclass Rotation:\n    \"\"\"\n    A utility class for rotation matrix operations.\n    \"\"\"\n\n    @staticmethod\n    def x(theta: jtp.Float) -> jtp.Matrix:\n        \"\"\"\n        Generate a 3D rotation matrix around the X-axis.\n\n        Args:\n            theta: Rotation angle in radians.\n\n        Returns:\n            The 3D rotation matrix.\n        \"\"\"\n\n        return jaxlie.SO3.from_x_radians(theta=theta).as_matrix()\n\n    @staticmethod\n    def y(theta: jtp.Float) -> jtp.Matrix:\n        \"\"\"\n        Generate a 3D rotation matrix around the Y-axis.\n\n        Args:\n            theta: Rotation angle in radians.\n\n        Returns:\n            The 3D rotation matrix.\n        \"\"\"\n\n        return jaxlie.SO3.from_y_radians(theta=theta).as_matrix()\n\n    @staticmethod\n    def z(theta: jtp.Float) -> jtp.Matrix:\n        \"\"\"\n        Generate a 3D rotation matrix around the Z-axis.\n\n        Args:\n            theta: Rotation angle in radians.\n\n        Returns:\n            The 3D rotation matrix.\n        \"\"\"\n\n        return jaxlie.SO3.from_z_radians(theta=theta).as_matrix()\n\n    @staticmethod\n    def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:\n        \"\"\"\n        Generate a 3D rotation matrix from an axis-angle representation.\n\n        Args:\n            vector: Axis-angle representation or the rotation as a 3D vector.\n\n        Returns:\n            The SO(3) rotation matrix.\n        \"\"\"\n\n        vector = vector.squeeze()\n\n        theta = safe_norm(vector)\n\n        s = jnp.sin(theta)\n        c = jnp.cos(theta)\n\n        c1 = 2 * jnp.sin(theta / 2.0) ** 2\n\n        safe_theta = jnp.where(theta == 0, 1.0, theta)\n        u = vector / safe_theta\n        u = jnp.vstack(u.squeeze())\n\n        R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T\n\n        
return R.transpose()\n\n    @staticmethod\n    def log_vee(R: jnp.ndarray) -> jtp.Vector:\n        \"\"\"\n        Compute the logarithm map of an SO(3) rotation matrix.\n\n        Args:\n            R: The SO(3) rotation matrix.\n\n        Returns:\n            The corresponding so(3) tangent vector.\n        \"\"\"\n\n        return jaxlie.SO3.from_matrix(R).log()\n"
  },
  {
    "path": "src/jaxsim/math/skew.py",
    "content": "import jax.numpy as jnp\n\nimport jaxsim.typing as jtp\n\n\nclass Skew:\n    \"\"\"\n    A utility class for skew-symmetric matrix operations.\n    \"\"\"\n\n    @staticmethod\n    def wedge(vector: jtp.Vector) -> jtp.Matrix:\n        \"\"\"\n        Compute the skew-symmetric matrix (wedge operator) of a 3D vector.\n\n        Args:\n            vector: A 3D vector.\n\n        Returns:\n            The skew-symmetric matrix corresponding to the input vector.\n\n        \"\"\"\n\n        vector = vector.reshape(-1, 3)\n\n        x, y, z = jnp.split(vector, 3, axis=-1)\n\n        skew = jnp.stack(\n            [\n                jnp.concatenate([jnp.zeros_like(x), -z, y], axis=-1),\n                jnp.concatenate([z, jnp.zeros_like(x), -x], axis=-1),\n                jnp.concatenate([-y, x, jnp.zeros_like(x)], axis=-1),\n            ],\n            axis=-2,\n        ).squeeze()\n\n        return skew\n\n    @staticmethod\n    def vee(matrix: jtp.Matrix) -> jtp.Vector:\n        \"\"\"\n        Extract the 3D vector from a skew-symmetric matrix (vee operator).\n\n        Args:\n            matrix: A 3x3 skew-symmetric matrix.\n\n        Returns:\n            The 3D vector extracted from the input matrix.\n\n        \"\"\"\n        vector = 0.5 * jnp.vstack(\n            [\n                matrix[2, 1] - matrix[1, 2],\n                matrix[0, 2] - matrix[2, 0],\n                matrix[1, 0] - matrix[0, 1],\n            ]\n        )\n        return vector\n"
  },
  {
    "path": "src/jaxsim/math/transform.py",
    "content": "import jax.numpy as jnp\nimport jaxlie\n\nimport jaxsim.typing as jtp\n\n\nclass Transform:\n    \"\"\"\n    A utility class for transformation matrix operations.\n    \"\"\"\n\n    @staticmethod\n    def from_quaternion_and_translation(\n        quaternion: jtp.VectorLike | None = None,\n        translation: jtp.VectorLike | None = None,\n        inverse: jtp.BoolLike = False,\n        normalize_quaternion: jtp.BoolLike = False,\n    ) -> jtp.Matrix:\n        \"\"\"\n        Create a transformation matrix from a quaternion and a translation.\n\n        Args:\n            quaternion: A 4D vector representing a SO(3) orientation.\n            translation: A 3D vector representing a translation.\n            inverse: Whether to compute the inverse transformation.\n            normalize_quaternion:\n                Whether to normalize the quaternion before creating the transformation.\n\n        Returns:\n            The 4x4 transformation matrix representing the SE(3) transformation.\n        \"\"\"\n\n        quaternion = quaternion if quaternion is not None else jnp.array([1.0, 0, 0, 0])\n        translation = translation if translation is not None else jnp.zeros(3)\n\n        W_Q_B = jnp.array(quaternion).astype(float)\n        W_p_B = jnp.array(translation).astype(float)\n\n        assert W_p_B.shape[-1] == 3\n        assert W_Q_B.shape[-1] == 4\n\n        A_R_B = jaxlie.SO3(wxyz=W_Q_B)\n        A_R_B = A_R_B if not normalize_quaternion else A_R_B.normalize()\n\n        A_H_B = jaxlie.SE3.from_rotation_and_translation(\n            rotation=A_R_B, translation=W_p_B\n        )\n\n        return A_H_B.as_matrix() if not inverse else A_H_B.inverse().as_matrix()\n\n    @staticmethod\n    def from_rotation_and_translation(\n        rotation: jtp.MatrixLike | None = None,\n        translation: jtp.VectorLike | None = None,\n        inverse: jtp.BoolLike = False,\n    ) -> jtp.Matrix:\n        \"\"\"\n        Create a transformation matrix from a 
rotation matrix and a translation vector.\n\n        Args:\n            rotation: A 3x3 rotation matrix representing a SO(3) orientation.\n            translation: A 3D vector representing a translation.\n            inverse: Whether to compute the inverse transformation.\n\n        Returns:\n            The 4x4 transformation matrix representing the SE(3) transformation.\n        \"\"\"\n        rotation = rotation if rotation is not None else jnp.eye(3)\n        translation = translation if translation is not None else jnp.zeros(3)\n\n        A_R_B = jnp.array(rotation).astype(float)\n        W_p_B = jnp.array(translation).astype(float)\n\n        assert W_p_B.size == 3\n        assert A_R_B.shape == (3, 3)\n\n        A_H_B = jaxlie.SE3.from_rotation_and_translation(\n            rotation=jaxlie.SO3.from_matrix(A_R_B), translation=W_p_B\n        )\n\n        return A_H_B.as_matrix() if not inverse else A_H_B.inverse().as_matrix()\n\n    @staticmethod\n    def inverse(transform: jtp.MatrixLike) -> jtp.Matrix:\n        \"\"\"\n        Compute the inverse transformation matrix.\n\n        Args:\n            transform: A 4x4 transformation matrix.\n\n        Returns:\n            The 4x4 inverse transformation matrix.\n        \"\"\"\n\n        return jaxlie.SE3.from_matrix(matrix=transform).inverse().as_matrix()\n"
  },
  {
    "path": "src/jaxsim/math/utils.py",
    "content": "import jax\nimport jax.numpy as jnp\n\nimport jaxsim.typing as jtp\n\n\ndef _make_safe_norm(axis, keepdims):\n    @jax.custom_jvp\n    def _safe_norm(array: jtp.ArrayLike) -> jtp.Array:\n        \"\"\"\n        Compute an array norm handling NaNs and making sure that\n        it is safe to get the gradient.\n\n        Args:\n            array: The array for which to compute the norm.\n\n        Returns:\n            The norm of the array with handling for zero arrays to avoid NaNs.\n        \"\"\"\n        # Compute the norm of the array along the specified axis.\n        return jnp.linalg.norm(array, axis=axis, keepdims=keepdims)\n\n    @_safe_norm.defjvp\n    def _safe_norm_jvp(primals, tangents):\n        (x,), (x_dot,) = primals, tangents\n\n        # Check if the entire array is composed of zeros.\n        is_zero = jnp.allclose(x, 0)\n\n        # Replace zeros with an array of ones temporarily to avoid division by zero.\n        # This ensures the computation of norm does not produce NaNs or Infs.\n        array = jnp.where(is_zero, jnp.ones_like(x), x)\n\n        # Compute the norm of the array along the specified axis.\n        norm = jnp.linalg.norm(array, axis=axis, keepdims=keepdims)\n\n        dot = jnp.sum(array * x_dot, axis=axis, keepdims=keepdims)\n        tangent = jnp.where(is_zero, 0.0, dot / norm)\n\n        return jnp.where(is_zero, 0.0, norm), tangent\n\n    return _safe_norm\n\n\ndef safe_norm(array: jtp.ArrayLike, *, axis=None, keepdims: bool = False) -> jtp.Array:\n    \"\"\"\n    Compute an array norm handling NaNs and making sure that\n    it is safe to get the gradient.\n\n    Args:\n        array: The array for which to compute the norm.\n        axis: The axis for which to compute the norm.\n        keepdims: Whether to keep the dimensions of the input\n\n    Returns:\n        The norm of the array with handling for zero arrays to avoid NaNs.\n    \"\"\"\n    return _make_safe_norm(axis, keepdims)(array)\n"
  },
  {
    "path": "src/jaxsim/mujoco/__init__.py",
    "content": "from .loaders import ModelToMjcf, RodModelToMjcf, SdfToMjcf, UrdfToMjcf\nfrom .model import MujocoModelHelper\nfrom .utils import MujocoCamera, mujoco_data_from_jaxsim\nfrom .visualizer import MujocoVideoRecorder, MujocoVisualizer\n"
  },
  {
    "path": "src/jaxsim/mujoco/__main__.py",
    "content": "import argparse\nimport pathlib\nimport sys\nimport time\n\nimport numpy as np\n\nfrom . import ModelToMjcf, MujocoModelHelper, MujocoVisualizer\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser(\n        prog=\"jaxsim.mujoco\",\n        description=\"Process URDF and SDF files for Mujoco usage.\",\n    )\n\n    parser.add_argument(\n        \"-d\",\n        \"--description\",\n        required=True,\n        metavar=\"INPUT_FILE\",\n        type=pathlib.Path,\n        help=\"Path to the URDF or SDF file.\",\n    )\n\n    parser.add_argument(\n        \"-m\",\n        \"--model-name\",\n        metavar=\"NAME\",\n        type=str,\n        default=None,\n        help=\"The target model of a SDF description if multiple models exists.\",\n    )\n\n    parser.add_argument(\n        \"-e\",\n        \"--export\",\n        metavar=\"MJCF_FILE\",\n        type=pathlib.Path,\n        default=None,\n        help=\"Path to the exported MJCF file.\",\n    )\n\n    parser.add_argument(\n        \"-f\",\n        \"--force\",\n        action=\"store_true\",\n        default=False,\n        help=\"Override the output MJCF file if it already exists (default: %(default)s).\",\n    )\n\n    parser.add_argument(\n        \"-p\",\n        \"--print\",\n        action=\"store_true\",\n        default=False,\n        help=\"Print in the stdout the exported MJCF string (default: %(default)s).\",\n    )\n\n    parser.add_argument(\n        \"-v\",\n        \"--visualize\",\n        action=\"store_true\",\n        default=False,\n        help=\"Visualize the description in the Mujoco viewer (default: %(default)s).\",\n    )\n\n    parser.add_argument(\n        \"-b\",\n        \"--base-position\",\n        metavar=(\"x\", \"y\", \"z\"),\n        nargs=3,\n        type=float,\n        default=None,\n        help=\"Override the base position (supports only floating-base models).\",\n    )\n\n    parser.add_argument(\n        \"-q\",\n        
\"--base-quaternion\",\n        metavar=(\"w\", \"x\", \"y\", \"z\"),\n        nargs=4,\n        type=float,\n        default=None,\n        help=\"Override the base quaternion (supports only floating-base models).\",\n    )\n\n    args = parser.parse_args()\n\n    # ==================\n    # Validate arguments\n    # ==================\n\n    # Expand the path of the URDF/SDF file if not absolute.\n    if args.description is not None:\n        args.description = (\n            (\n                args.description\n                if args.description.is_absolute()\n                else pathlib.Path.cwd() / args.description\n            )\n            .expanduser()\n            .absolute()\n        )\n\n        if not pathlib.Path(args.description).is_file():\n            msg = f\"The URDF/SDF file '{args.description}' does not exist.\"\n            parser.error(msg)\n            sys.exit(1)\n\n    # Expand the path of the output MJCF file if not absolute.\n    if args.export is not None:\n        args.export = (\n            (\n                args.export\n                if args.export.is_absolute()\n                else pathlib.Path.cwd() / args.export\n            )\n            .expanduser()\n            .absolute()\n        )\n\n        if pathlib.Path(args.export).is_file() and not args.force:\n            msg = \"The output file '{}' already exists, use '--force' to override.\"\n            parser.error(msg.format(args.export))\n            sys.exit(1)\n\n    # ================================================\n    # Load the URDF/SDF file and produce a MJCF string\n    # ================================================\n\n    mjcf_string, assets = ModelToMjcf.convert(args.description)\n\n    if args.print:\n        print(mjcf_string, flush=True)\n\n    # ========================================\n    # Write the MJCF string to the output file\n    # ========================================\n\n    if args.export is not None:\n        with open(args.export, 
\"w+\", encoding=\"utf-8\") as file:\n            file.write(mjcf_string)\n\n    # =======================================\n    # Visualize the MJCF in the Mujoco viewer\n    # =======================================\n\n    if args.visualize:\n\n        mj_model_helper = MujocoModelHelper.build_from_xml(\n            mjcf_description=mjcf_string, assets=assets\n        )\n\n        viz = MujocoVisualizer(model=mj_model_helper.model, data=mj_model_helper.data)\n\n        with viz.open() as viewer:\n\n            with viewer.lock():\n                if args.base_position is not None:\n                    mj_model_helper.set_base_position(\n                        position=np.array(args.base_position)\n                    )\n\n                if args.base_quaternion is not None:\n                    mj_model_helper.set_base_orientation(\n                        orientation=np.array(args.base_quaternion)\n                    )\n\n            viz.sync(viewer=viewer)\n\n            while viewer.is_running():\n                time.sleep(0.500)\n\n    # =============================\n    # Exit the program with success\n    # =============================\n\n    sys.exit(0)\n"
  },
  {
    "path": "src/jaxsim/mujoco/loaders.py",
    "content": "import pathlib\nimport tempfile\nimport warnings\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport jaxlie\nimport mujoco as mj\nimport numpy as np\nimport rod.urdf.exporter\nfrom lxml import etree as ET\n\nfrom jaxsim import logging\n\nfrom .utils import MujocoCamera\n\nMujocoCameraType = (\n    MujocoCamera | Sequence[MujocoCamera] | dict[str, str] | Sequence[dict[str, str]]\n)\n\n\ndef load_rod_model(\n    model_description: str | pathlib.Path | rod.Model,\n    is_urdf: bool | None = None,\n    model_name: str | None = None,\n) -> rod.Model:\n    \"\"\"\n    Load a ROD model from a URDF/SDF file or a ROD model.\n\n    Args:\n        model_description: The URDF/SDF file or ROD model to load.\n        is_urdf: Whether to force parsing the model description as a URDF file.\n        model_name: The name of the model to load from the resource.\n\n    Returns:\n        rod.Model: The loaded ROD model.\n    \"\"\"\n\n    # Parse the SDF resource.\n    sdf_element = rod.Sdf.load(sdf=model_description, is_urdf=is_urdf)\n\n    # Fail if the SDF resource has no model.\n    if len(sdf_element.models()) == 0:\n        raise RuntimeError(\"Failed to find any model in the model description\")\n\n    # Return the model if there is only one.\n    if len(sdf_element.models()) == 1:\n        if model_name is not None and sdf_element.models()[0].name != model_name:\n            raise ValueError(f\"Model '{model_name}' not found in the description\")\n\n        return sdf_element.models()[0]\n\n    # Require users to specify the model name if there are multiple models.\n    if model_name is None:\n        msg = \"The resource has multiple models. 
Please specify the model name.\"\n        raise ValueError(msg)\n\n    # Build a dictionary of models in the resource for easy access.\n    models = {m.name: m for m in sdf_element.models()}\n\n    if model_name not in models:\n        raise ValueError(f\"Model '{model_name}' not found in the resource\")\n\n    return models[model_name]\n\n\nclass ModelToMjcf:\n    \"\"\"\n    Class to convert a URDF/SDF file or a ROD model to a Mujoco MJCF string.\n    \"\"\"\n\n    @staticmethod\n    def convert(\n        model: str | pathlib.Path | rod.Model,\n        considered_joints: list[str] | None = None,\n        plane_normal: tuple[float, float, float] = (0, 0, 1),\n        heightmap: bool | None = None,\n        heightmap_samples_xy: tuple[int, int] = (101, 101),\n        cameras: MujocoCameraType = (),\n    ) -> tuple[str, dict[str, Any]]:\n        \"\"\"\n        Convert a model to a Mujoco MJCF string.\n\n        Args:\n            model: The URDF/SDF file or ROD model to convert.\n            considered_joints: The list of joint names to consider in the conversion.\n            plane_normal: The normal vector of the plane.\n            heightmap: Whether to generate a heightmap.\n            heightmap_samples_xy: The number of points in the heightmap grid.\n            cameras: The custom cameras to add to the scene.\n\n        Returns:\n            A tuple containing the MJCF string and the dictionary of assets.\n        \"\"\"\n\n        match model:\n            case rod.Model():\n                rod_model = model\n            case str() | pathlib.Path():\n                # Convert the JaxSim model to a ROD model.\n                rod_model = load_rod_model(\n                    model_description=model,\n                    is_urdf=None,\n                    model_name=None,\n                )\n            case _:\n                raise TypeError(f\"Unsupported type for 'model': {type(model)}\")\n\n        # Convert the ROD model to MJCF.\n        return 
RodModelToMjcf.convert(\n            rod_model=rod_model,\n            considered_joints=considered_joints,\n            plane_normal=plane_normal,\n            heightmap=heightmap,\n            heightmap_samples_xy=heightmap_samples_xy,\n            cameras=cameras,\n        )\n\n\nclass RodModelToMjcf:\n    \"\"\"\n    Class to convert a ROD model to a Mujoco MJCF string.\n    \"\"\"\n\n    @staticmethod\n    def assets_from_rod_model(\n        rod_model: rod.Model,\n    ) -> dict[str, bytes]:\n        \"\"\"\n        Generate a dictionary of assets from a ROD model.\n\n        Args:\n            rod_model: The ROD model to extract the assets from.\n\n        Returns:\n            dict: A dictionary of assets.\n        \"\"\"\n\n        import resolve_robotics_uri_py\n\n        assets_files = dict()\n\n        for link in rod_model.links():\n            for visual in link.visuals():\n                if visual.geometry.mesh and visual.geometry.mesh.uri:\n                    assets_files[visual.geometry.mesh.uri] = (\n                        resolve_robotics_uri_py.resolve_robotics_uri(\n                            visual.geometry.mesh.uri\n                        )\n                    )\n\n            for collision in link.collisions():\n                if collision.geometry.mesh and collision.geometry.mesh.uri:\n                    assets_files[collision.geometry.mesh.uri] = (\n                        resolve_robotics_uri_py.resolve_robotics_uri(\n                            collision.geometry.mesh.uri\n                        )\n                    )\n\n        assets = {\n            asset_name: asset.read_bytes() for asset_name, asset in assets_files.items()\n        }\n\n        return assets\n\n    @staticmethod\n    def add_floating_joint(\n        urdf_string: str,\n        base_link_name: str,\n        floating_joint_name: str = \"world_to_base\",\n    ) -> str:\n        \"\"\"\n        Add a floating joint to a URDF string.\n\n        Args:\n            
urdf_string: The URDF string to modify.\n            base_link_name: The name of the base link to attach the floating joint.\n            floating_joint_name: The name of the floating joint to add.\n\n        Returns:\n            str: The modified URDF string.\n        \"\"\"\n\n        with tempfile.NamedTemporaryFile(mode=\"w+\", suffix=\".urdf\") as urdf_file:\n\n            # Write the URDF string to a temporary file and move current position\n            # to the beginning.\n            urdf_file.write(urdf_string)\n            urdf_file.seek(0)\n\n            # Parse the MJCF string as XML (etree).\n            parser = ET.XMLParser(remove_blank_text=True)\n            tree = ET.parse(source=urdf_file, parser=parser)\n\n        root: ET._Element = tree.getroot()\n\n        if root.find(f\".//joint[@name='{floating_joint_name}']\") is not None:\n            msg = f\"The URDF already has a floating joint '{floating_joint_name}'\"\n            warnings.warn(msg, stacklevel=2)\n            return ET.tostring(root, pretty_print=True).decode()\n\n        # Create the \"world\" link if it doesn't exist.\n        if root.find(\".//link[@name='world']\") is None:\n            _ = ET.SubElement(root, \"link\", name=\"world\")\n\n        # Create the floating joint.\n        world_to_base = ET.SubElement(\n            root, \"joint\", name=floating_joint_name, type=\"floating\"\n        )\n\n        # Check that the base link exists.\n        if root.find(f\".//link[@name='{base_link_name}']\") is None:\n            raise ValueError(f\"Link '{base_link_name}' not found in the URDF\")\n\n        # Attach the floating joint to the base link.\n        ET.SubElement(world_to_base, \"parent\", link=\"world\")\n        ET.SubElement(world_to_base, \"child\", link=base_link_name)\n\n        urdf_string = ET.tostring(root, pretty_print=True).decode()\n        return urdf_string\n\n    @staticmethod\n    def convert(\n        rod_model: rod.Model,\n        considered_joints: 
list[str] | None = None,\n        plane_normal: tuple[float, float, float] = (0, 0, 1),\n        heightmap: bool | None = None,\n        heightmap_samples_xy: tuple[int, int] = (101, 101),\n        cameras: MujocoCameraType = (),\n    ) -> tuple[str, dict[str, Any]]:\n        \"\"\"\n        Convert a ROD model to a Mujoco MJCF string.\n\n        Args:\n            rod_model: The ROD model to convert.\n            considered_joints: The list of joint names to consider in the conversion.\n            plane_normal: The normal vector of the plane.\n            heightmap: Whether to generate a heightmap.\n            heightmap_samples_xy: The number of points in the heightmap grid.\n            cameras: The custom cameras to add to the scene.\n\n        Returns:\n            A tuple containing the MJCF string and the dictionary of assets.\n        \"\"\"\n\n        # -------------------------------------\n        # Convert the model description to URDF\n        # -------------------------------------\n\n        # Consider all joints if not specified otherwise.\n        considered_joints = set(\n            considered_joints\n            if considered_joints is not None\n            else [j.name for j in rod_model.joints() if j.type != \"fixed\"]\n        )\n\n        # If considered joints are passed, make sure that they are all part of the model.\n        if considered_joints - {j.name for j in rod_model.joints()}:\n            extra_joints = considered_joints - {j.name for j in rod_model.joints()}\n\n            msg = f\"Couldn't find the following joints in the model: '{extra_joints}'\"\n            raise ValueError(msg)\n\n        # Create a dictionary of joints for quick access.\n        joints_dict = {j.name: j for j in rod_model.joints()}\n\n        # Convert all the joints not considered to fixed joints.\n        for joint_name in {j.name for j in rod_model.joints()} - considered_joints:\n            joints_dict[joint_name].type = \"fixed\"\n\n        # Convert 
the ROD model to URDF.\n        urdf_string = rod.urdf.exporter.UrdfExporter(\n            gazebo_preserve_fixed_joints=False, pretty=True\n        ).to_urdf_string(\n            sdf=rod.Sdf(model=rod_model, version=\"1.7\"),\n        )\n\n        # -------------------------------------\n        # Add a floating joint if floating-base\n        # -------------------------------------\n\n        base_link_name = rod_model.get_canonical_link()\n\n        if not rod_model.is_fixed_base():\n            considered_joints |= {\"world_to_base\"}\n            urdf_string = RodModelToMjcf.add_floating_joint(\n                urdf_string=urdf_string,\n                base_link_name=base_link_name,\n                floating_joint_name=\"world_to_base\",\n            )\n\n        # ---------------------------------------\n        # Inject the <mujoco> element in the URDF\n        # ---------------------------------------\n\n        parser = ET.XMLParser(remove_blank_text=True)\n        root = ET.fromstring(text=urdf_string.encode(), parser=parser)\n\n        mujoco_element = (\n            ET.SubElement(root, \"mujoco\")\n            if len(root.findall(\"./mujoco\")) == 0\n            else root.find(\"./mujoco\")\n        )\n\n        _ = ET.SubElement(\n            mujoco_element,\n            \"compiler\",\n            balanceinertia=\"true\",\n            discardvisual=\"false\",\n        )\n\n        urdf_string = ET.tostring(root, pretty_print=True).decode()\n\n        # ------------------------------\n        # Post-process all dummy visuals\n        # ------------------------------\n\n        parser = ET.XMLParser(remove_blank_text=True)\n        root: ET._Element = ET.fromstring(text=urdf_string.encode(), parser=parser)\n\n        # Give a tiny radius to all dummy spheres\n        for geometry in root.findall(\".//visual/geometry[sphere]\"):\n            radius = np.fromstring(\n                geometry.find(\"./sphere\").attrib[\"radius\"], sep=\" \", dtype=float\n    
        )\n            if np.allclose(radius, np.zeros(1)):\n                geometry.find(\"./sphere\").set(\"radius\", \"0.001\")\n\n        # Give a tiny volume to all dummy boxes\n        for geometry in root.findall(\".//visual/geometry[box]\"):\n            size = np.fromstring(\n                geometry.find(\"./box\").attrib[\"size\"], sep=\" \", dtype=float\n            )\n            if np.allclose(size, np.zeros(3)):\n                geometry.find(\"./box\").set(\"size\", \"0.001 0.001 0.001\")\n\n        urdf_string = ET.tostring(root, pretty_print=True).decode()\n\n        # ------------------------\n        # Convert the URDF to MJCF\n        # ------------------------\n\n        # Load the URDF model into Mujoco.\n        assets = RodModelToMjcf.assets_from_rod_model(rod_model=rod_model)\n        mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets)\n\n        # Get the joint names.\n        mj_joint_names = {\n            mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx)\n            for idx in range(mj_model.njnt)\n        }\n\n        # Check that the Mujoco model only has the considered joints.\n        if mj_joint_names != considered_joints:\n            extra1 = mj_joint_names - considered_joints\n            extra2 = considered_joints - mj_joint_names\n            extra_joints = extra1.union(extra2)\n            msg = \"The Mujoco model has the following extra/missing joints: '{}'\"\n            raise ValueError(msg.format(extra_joints))\n\n        # Windows locks open files, so we use mkstemp() to create a temporary file without keeping it open.\n        with tempfile.NamedTemporaryFile(\n            suffix=\".xml\", prefix=f\"{rod_model.name}_\", delete=False\n        ) as tmp:\n            temp_filename = tmp.name\n\n        try:\n            # Convert the in-memory Mujoco model to MJCF.\n            mj.mj_saveLastXML(temp_filename, mj_model)\n\n            # Parse the MJCF file as XML.\n            parser = 
ET.XMLParser(remove_blank_text=True)\n            tree = ET.parse(source=temp_filename, parser=parser)\n\n        finally:\n            pathlib.Path(temp_filename).unlink()\n\n        # Get the root element.\n        root: ET._Element = tree.getroot()\n\n        # Find the <mujoco> element (might be the root itself).\n        mujoco_element: ET._Element = next(iter(root.iter(\"mujoco\")))\n\n        # --------------\n        # Add the frames\n        # --------------\n\n        for frame in rod_model.frames():\n            frame: rod.Frame\n            parent_name = frame.attached_to\n            parent_element = mujoco_element.find(f\".//body[@name='{parent_name}']\")\n\n            if parent_element is None and parent_name == base_link_name:\n                parent_element = mujoco_element.find(\".//worldbody\")\n\n            if parent_element is not None:\n                quat = jaxlie.SO3.from_rpy_radians(*frame.pose.rpy).wxyz\n                _ = ET.SubElement(\n                    parent_element,\n                    \"site\",\n                    name=frame.name,\n                    pos=\" \".join(map(str, frame.pose.xyz)),\n                    quat=\" \".join(map(str, quat)),\n                )\n            else:\n                logging.debug(f\"Parent link '{parent_name}' not found\")\n\n        # --------------\n        # Add the motors\n        # --------------\n\n        if len(mujoco_element.findall(\".//actuator\")) > 0:\n            raise RuntimeError(\"The model already has <actuator> elements.\")\n\n        # Add the actuator element.\n        actuator_element = ET.SubElement(mujoco_element, \"actuator\")\n\n        # Add a motor for each joint.\n        for joint_element in mujoco_element.findall(\".//joint\"):\n            assert (\n                joint_element.attrib[\"name\"] in considered_joints\n            ), joint_element.attrib[\"name\"]\n            if joint_element.attrib.get(\"type\", \"hinge\") in {\"free\", \"ball\"}:\n            
    continue\n            ET.SubElement(\n                actuator_element,\n                \"motor\",\n                name=f\"{joint_element.attrib['name']}_motor\",\n                joint=joint_element.attrib[\"name\"],\n                gear=\"1\",\n            )\n\n        # ---------------------------------------------\n        # Set full transparency of collision geometries\n        # ---------------------------------------------\n\n        parser = ET.XMLParser(remove_blank_text=True)\n\n        # Get all the (optional) names of the URDF collision elements\n        collision_names = {\n            c.attrib[\"name\"]\n            for c in ET.fromstring(text=urdf_string.encode(), parser=parser).findall(\n                \".//collision[geometry]\"\n            )\n            if \"name\" in c.attrib\n        }\n\n        # Set alpha=0 to the color of all collision elements\n        for geometry_element in mujoco_element.findall(\".//geom[@rgba]\"):\n            if geometry_element.attrib.get(\"name\") in collision_names:\n                r, g, b, _ = geometry_element.attrib[\"rgba\"].split(\" \")\n                geometry_element.set(\"rgba\", f\"{r} {g} {b} 0\")\n\n        # -----------------------\n        # Create the scene assets\n        # -----------------------\n\n        asset_element = (\n            ET.SubElement(mujoco_element, \"asset\")\n            if len(mujoco_element.findall(\".//asset\")) == 0\n            else mujoco_element.find(\".//asset\")\n        )\n\n        _ = ET.SubElement(\n            asset_element,\n            \"texture\",\n            type=\"skybox\",\n            builtin=\"gradient\",\n            rgb1=\"0.3 0.5 0.7\",\n            rgb2=\"0 0 0\",\n            width=\"512\",\n            height=\"512\",\n        )\n\n        _ = ET.SubElement(\n            asset_element,\n            \"texture\",\n            name=\"plane_texture\",\n            type=\"2d\",\n            builtin=\"checker\",\n            rgb1=\"0.1 0.2 
0.3\",\n            rgb2=\"0.2 0.3 0.4\",\n            width=\"512\",\n            height=\"512\",\n            mark=\"cross\",\n            markrgb=\".8 .8 .8\",\n        )\n\n        _ = ET.SubElement(\n            asset_element,\n            \"material\",\n            name=\"plane_material\",\n            texture=\"plane_texture\",\n            reflectance=\"0.2\",\n            texrepeat=\"5 5\",\n            texuniform=\"true\",\n        )\n\n        _ = (\n            ET.SubElement(\n                asset_element,\n                \"hfield\",\n                name=\"terrain\",\n                nrow=str(int(heightmap_samples_xy[0])),\n                ncol=str(int(heightmap_samples_xy[1])),\n                # The following 'size' is a placeholder, it is updated dynamically\n                # when a hfield/heightmap is stored into MjData.\n                size=\"1 1 1 1\",\n            )\n            if heightmap\n            else None\n        )\n\n        # ----------------------------------\n        # Populate the scene with the assets\n        # ----------------------------------\n\n        worldbody_scene_element = ET.SubElement(mujoco_element, \"worldbody\")\n\n        _ = ET.SubElement(\n            worldbody_scene_element,\n            \"geom\",\n            name=\"floor\",\n            type=\"plane\" if not heightmap else \"hfield\",\n            size=\"0 0 0.05\",\n            material=\"plane_material\",\n            condim=\"3\",\n            contype=\"1\",\n            conaffinity=\"1\",\n            zaxis=\" \".join(map(str, plane_normal)),\n            **({\"hfield\": \"terrain\"} if heightmap else {}),\n        )\n\n        _ = ET.SubElement(\n            worldbody_scene_element,\n            \"light\",\n            name=\"sun\",\n            mode=\"fixed\",\n            directional=\"true\",\n            castshadow=\"true\",\n            pos=\"0 0 10\",\n            dir=\"0 0 -1\",\n        )\n\n        # 
-------------------------------------------------------\n        # Add a camera following the CoM of the worldbody element\n        # -------------------------------------------------------\n\n        worldbody_element = None\n\n        # Find the <worldbody> element of our model by searching the one that contains\n        # all the considered joints. This is needed because there might be multiple\n        # <worldbody> elements inside <mujoco>.\n        for wb in mujoco_element.findall(\".//worldbody\"):\n            if all(\n                wb.find(f\".//joint[@name='{j}']\") is not None for j in considered_joints\n            ):\n                worldbody_element = wb\n                break\n\n        if worldbody_element is None:\n            raise RuntimeError(\"Failed to find the <worldbody> element of the model\")\n\n        # Camera attached to the model\n        # It can be manually copied from `python -m mujoco.viewer --mjcf=<URDF_PATH>`\n        _ = ET.SubElement(\n            worldbody_element,\n            \"camera\",\n            name=\"track\",\n            mode=\"trackcom\",\n            pos=\"1.930 -2.279 0.556\",\n            xyaxes=\"0.771 0.637 0.000 -0.116 0.140 0.983\",\n            fovy=\"60\",\n        )\n\n        # Add user-defined camera.\n        for camera in cameras if isinstance(cameras, Sequence) else [cameras]:\n\n            mj_camera = (\n                camera\n                if isinstance(camera, MujocoCamera)\n                else MujocoCamera.build(**camera)\n            )\n\n            _ = ET.SubElement(worldbody_element, \"camera\", mj_camera.asdict())\n\n        # ------------------------------------------------\n        # Add a light following the  CoM of the first link\n        # ------------------------------------------------\n\n        if not rod_model.is_fixed_base():\n\n            # Light attached to the model\n            _ = ET.SubElement(\n                worldbody_element,\n                \"light\",\n         
       name=\"light_model\",\n                mode=\"targetbodycom\",\n                target=worldbody_element.find(\".//body\").attrib[\"name\"],\n                directional=\"false\",\n                castshadow=\"true\",\n                pos=\"1 0 5\",\n            )\n\n        # --------------------------------\n        # Return the resulting MJCF string\n        # --------------------------------\n\n        mjcf_string = ET.tostring(root, pretty_print=True).decode()\n        return mjcf_string, assets\n\n\nclass UrdfToMjcf:\n    \"\"\"\n    Class to convert a URDF file to a Mujoco MJCF string.\n    \"\"\"\n\n    @staticmethod\n    def convert(\n        urdf: str | pathlib.Path,\n        considered_joints: list[str] | None = None,\n        model_name: str | None = None,\n        plane_normal: tuple[float, float, float] = (0, 0, 1),\n        heightmap: bool | None = None,\n        cameras: MujocoCameraType = (),\n    ) -> tuple[str, dict[str, Any]]:\n        \"\"\"\n        Convert a URDF file to a Mujoco MJCF string.\n\n        Args:\n            urdf: The URDF file to convert.\n            considered_joints: The list of joint names to consider in the conversion.\n            model_name: The name of the model to convert.\n            plane_normal: The normal vector of the plane.\n            heightmap: Whether to generate a heightmap.\n            cameras: The list of cameras to add to the scene.\n\n        Returns:\n            tuple: A tuple containing the MJCF string and the assets dictionary.\n        \"\"\"\n\n        logging.warning(\"This method is deprecated. 
Use 'ModelToMjcf.convert' instead.\")\n\n        # Get the ROD model.\n        rod_model = load_rod_model(\n            model_description=urdf,\n            is_urdf=True,\n            model_name=model_name,\n        )\n\n        # Convert the ROD model to MJCF.\n        return RodModelToMjcf.convert(\n            rod_model=rod_model,\n            considered_joints=considered_joints,\n            plane_normal=plane_normal,\n            heightmap=heightmap,\n            cameras=cameras,\n        )\n\n\nclass SdfToMjcf:\n    \"\"\"\n    Class to convert a SDF file to a Mujoco MJCF string.\n    \"\"\"\n\n    @staticmethod\n    def convert(\n        sdf: str | pathlib.Path,\n        considered_joints: list[str] | None = None,\n        model_name: str | None = None,\n        plane_normal: tuple[float, float, float] = (0, 0, 1),\n        heightmap: bool | None = None,\n        cameras: MujocoCameraType = (),\n    ) -> tuple[str, dict[str, Any]]:\n        \"\"\"\n        Convert a SDF file to a Mujoco MJCF string.\n\n        Args:\n            sdf: The SDF file to convert.\n            considered_joints: The list of joint names to consider in the conversion.\n            model_name: The name of the model to convert.\n            plane_normal: The normal vector of the plane.\n            heightmap: Whether to generate a heightmap.\n            cameras: The list of cameras to add to the scene.\n\n        Returns:\n            tuple: A tuple containing the MJCF string and the assets dictionary.\n        \"\"\"\n\n        logging.warning(\"This method is deprecated. 
Use 'ModelToMjcf.convert' instead.\")\n\n        # Get the ROD model.\n        rod_model = load_rod_model(\n            model_description=sdf,\n            is_urdf=False,\n            model_name=model_name,\n        )\n\n        # Convert the ROD model to MJCF.\n        return RodModelToMjcf.convert(\n            rod_model=rod_model,\n            considered_joints=considered_joints,\n            plane_normal=plane_normal,\n            heightmap=heightmap,\n            cameras=cameras,\n        )\n"
  },
  {
    "path": "src/jaxsim/mujoco/model.py",
    "content": "from __future__ import annotations\n\nimport functools\nimport pathlib\nfrom collections.abc import Callable, Sequence\nfrom typing import Any\n\nimport mujoco as mj\nimport numpy as np\nimport numpy.typing as npt\nimport xmltodict\nfrom scipy.spatial.transform import Rotation\n\nimport jaxsim.typing as jtp\n\nHeightmapCallable = Callable[[jtp.FloatLike, jtp.FloatLike], jtp.FloatLike]\n\n\nclass MujocoModelHelper:\n    \"\"\"\n    Helper class to create and interact with Mujoco models and data objects.\n    \"\"\"\n\n    def __init__(self, model: mj.MjModel, data: mj.MjData | None = None) -> None:\n        \"\"\"\n        Initialize the MujocoModelHelper object.\n\n        Args:\n            model: A Mujoco model object.\n            data: A Mujoco data object. If None, a new one will be created.\n        \"\"\"\n\n        self.model = model\n        self.data = data if data is not None else mj.MjData(self.model)\n\n        # Populate the data with kinematics.\n        mj.mj_forward(self.model, self.data)\n\n        # Keep the cache of this method local to improve GC.\n        self.mask_qpos = functools.cache(self._mask_qpos)\n\n    @staticmethod\n    def build_from_xml(\n        mjcf_description: str | pathlib.Path,\n        assets: dict[str, Any] | None = None,\n        heightmap: HeightmapCallable | None = None,\n        heightmap_name: str = \"terrain\",\n        heightmap_radius_xy: tuple[float, float] = (1.0, 1.0),\n    ) -> MujocoModelHelper:\n        \"\"\"\n        Build a Mujoco model from an MJCF description.\n\n        Args:\n            mjcf_description:\n                A string containing the XML description of the Mujoco model\n                or a path to a file containing the XML description.\n            assets: An optional dictionary containing the assets of the model.\n            heightmap:\n                A function in two variables that returns the height of a terrain\n                in the specified coordinate point.\n     
       heightmap_name:\n                The default name of the heightmap in the MJCF description\n                to load the corresponding configuration.\n            heightmap_radius_xy:\n                The extension of the heightmap in the x-y surface corresponding to the\n                plane over which the grid of the sampled heightmap is generated.\n\n        Returns:\n            A MujocoModelHelper object.\n        \"\"\"\n\n        # Read the XML description if it is a path to file.\n        mjcf_description = (\n            mjcf_description.read_text()\n            if isinstance(mjcf_description, pathlib.Path)\n            else mjcf_description\n        )\n\n        if heightmap is None:\n            hfield = None\n\n        else:\n\n            mjcf_description_dict = xmltodict.parse(xml_input=mjcf_description)\n\n            # Create a dictionary of all hfield configurations from the MJCF.\n            hfields = mjcf_description_dict[\"mujoco\"][\"asset\"].get(\"hfield\", [])\n            hfields = hfields if isinstance(hfields, list) else [hfields]\n            hfields_dict = {hfield[\"@name\"]: hfield for hfield in hfields}\n\n            if heightmap_name not in hfields_dict:\n                raise ValueError(f\"Heightmap '{heightmap_name}' not found in MJCF\")\n\n            hfield_element = hfields_dict[heightmap_name]\n\n            # Generate the hfield by sampling the heightmap function.\n            hfield = generate_hfield(\n                heightmap=heightmap,\n                samples_xy=(int(hfield_element[\"@nrow\"]), int(hfield_element[\"@ncol\"])),\n                radius_xy=heightmap_radius_xy,\n            )\n\n            # Update dynamically the '/asset/hfield[@name=heightmap_name]@size' attribute\n            # with the information of the sampled points.\n            # This is necessary for correctly rendering the heightmap over the\n            # specified xy area with the correct z elevation.\n            size = [float(el) for 
el in hfield_element[\"@size\"].split(\" \")]\n            size[0], size[1] = heightmap_radius_xy\n            size[2] = 1.0\n            # The following could be zero but Mujoco complains if it's exactly zero.\n            size[3] = max(0.000_001, -min(hfield))\n\n            # Replace the 'size' attribute.\n            hfields_dict[heightmap_name][\"@size\"] = \" \".join(str(el) for el in size)\n\n            # Update the hfield elements of the original MJCF.\n            # Only the hfield corresponding to 'heightmap_name' was actually edited.\n            mjcf_description_dict[\"mujoco\"][\"asset\"][\"hfield\"] = list(\n                hfields_dict.values()\n            )\n\n            # Serialize the updated MJCF to XML.\n            mjcf_description = xmltodict.unparse(\n                input_dict=mjcf_description_dict, pretty=True\n            )\n\n        # Create the Mujoco model from the XML and, optionally, the dictionary of assets.\n        model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets)\n        data = mj.MjData(model)\n\n        # Store the sampled heightmap into the Mujoco model.\n        if heightmap is not None:\n            assert hfield is not None\n            model.hfield_data = hfield\n\n        return MujocoModelHelper(model=model, data=data)\n\n    def time(self) -> float:\n        \"\"\"Return the simulation time.\"\"\"\n\n        return self.data.time\n\n    def timestep(self) -> float:\n        \"\"\"Return the simulation timestep.\"\"\"\n\n        return self.model.opt.timestep\n\n    def gravity(self) -> npt.NDArray:\n        \"\"\"Return the 3D gravity vector.\"\"\"\n\n        return np.array([0, 0, self.model.gravity])\n\n    # =========================\n    # Methods for the base link\n    # =========================\n\n    def is_floating_base(self) -> bool:\n        \"\"\"Return true if the model is floating-base.\"\"\"\n\n        # A body with no joints is considered a fixed-base model.\n        # In fact, 
in mujoco, a floating-base model has a 6 DoFs first joint.\n        if self.number_of_joints() == 0:\n            return False\n\n        # We just check that the first joint has 6 DoFs.\n        joint0_type = self.model.jnt_type[0]\n        return joint0_type == mj.mjtJoint.mjJNT_FREE\n\n    def is_fixed_base(self) -> bool:\n        \"\"\"Return true if the model is fixed-base.\"\"\"\n\n        return not self.is_floating_base()\n\n    def base_link(self) -> str:\n        \"\"\"Return the name of the base link.\"\"\"\n\n        return mj.mj_id2name(\n            self.model, mj.mjtObj.mjOBJ_BODY, 0 if self.is_fixed_base() else 1\n        )\n\n    def base_position(self) -> npt.NDArray:\n        \"\"\"Return the 3D position of the base link.\"\"\"\n\n        return (\n            self.data.qpos[:3]\n            if self.is_floating_base()\n            else self.body_position(body_name=self.base_link())\n        )\n\n    def base_orientation(self, dcm: bool = False) -> npt.NDArray:\n        \"\"\"Return the orientation of the base link.\"\"\"\n\n        return (\n            (\n                np.reshape(self.data.xmat[0], newshape=(3, 3))\n                if dcm\n                else self.data.xquat[0]\n            )\n            if self.is_floating_base()\n            else self.body_orientation(body_name=self.base_link(), dcm=dcm)\n        )\n\n    def set_base_position(self, position: npt.NDArray) -> None:\n        \"\"\"Set the 3D position of the base link.\"\"\"\n\n        if self.is_fixed_base():\n            raise ValueError(\"The position of a fixed-base model cannot be set.\")\n\n        position = np.atleast_1d(np.array(position).squeeze())\n\n        if position.size != 3:\n            raise ValueError(f\"Wrong position size ({position.size})\")\n\n        self.data.qpos[:3] = position\n\n    def set_base_orientation(self, orientation: npt.NDArray, dcm: bool = False) -> None:\n        \"\"\"Set the 3D position of the base link.\"\"\"\n\n        if 
self.is_fixed_base():\n            raise ValueError(\"The orientation of a fixed-base model cannot be set.\")\n\n        orientation = (\n            np.atleast_2d(np.array(orientation).squeeze())\n            if dcm\n            else np.atleast_1d(np.array(orientation).squeeze())\n        )\n\n        if orientation.shape != ((4,) if not dcm else (3, 3)):\n            raise ValueError(f\"Wrong orientation shape {orientation.shape}\")\n\n        def is_quaternion(Q):\n            return np.allclose(np.linalg.norm(Q), 1.0)\n\n        def is_dcm(R):\n            return np.allclose(np.linalg.det(R), 1.0) and np.allclose(\n                R.T @ R, np.eye(3)\n            )\n\n        if not (is_quaternion(orientation) if not dcm else is_dcm(orientation)):\n            raise ValueError(\"The orientation is not a valid element of SO(3)\")\n\n        W_Q_B = (\n            Rotation.from_matrix(orientation).as_quat(\n                canonical=True, scalar_first=False\n            )\n            if dcm\n            else orientation\n        )\n\n        self.data.qpos[3:7] = W_Q_B\n\n    # ==================\n    # Methods for joints\n    # ==================\n\n    def number_of_joints(self) -> int:\n        \"\"\"Return the number of joints in the model.\"\"\"\n\n        return self.model.njnt\n\n    def number_of_dofs(self) -> int:\n        \"\"\"Return the number of DoFs in the model.\"\"\"\n\n        return self.model.nq\n\n    def joint_names(self) -> list[str]:\n        \"\"\"Return the names of the joints in the model.\"\"\"\n\n        return [\n            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, idx)\n            for idx in range(0 if self.is_fixed_base() else 1, self.number_of_joints())\n        ]\n\n    def joint_dofs(self, joint_name: str) -> int:\n        \"\"\"Return the number of DoFs of a joint.\"\"\"\n\n        if joint_name not in self.joint_names():\n            raise ValueError(f\"Joint '{joint_name}' not found\")\n\n        return 
self.data.joint(joint_name).qpos.size\n\n    def joint_position(self, joint_name: str) -> npt.NDArray:\n        \"\"\"Return the position of a joint.\"\"\"\n\n        if joint_name not in self.joint_names():\n            raise ValueError(f\"Joint '{joint_name}' not found\")\n\n        return self.data.joint(joint_name).qpos\n\n    def joint_positions(self, joint_names: list[str] | None = None) -> npt.NDArray:\n        \"\"\"Return the positions of the joints.\"\"\"\n\n        joint_names = joint_names if joint_names is not None else self.joint_names()\n\n        return np.hstack(\n            [self.joint_position(joint_name) for joint_name in joint_names]\n        )\n\n    def set_joint_position(\n        self, joint_name: str, position: npt.NDArray | float\n    ) -> None:\n        \"\"\"Set the position of a joint.\"\"\"\n\n        position = np.atleast_1d(np.array(position).squeeze())\n\n        if position.size != self.joint_dofs(joint_name=joint_name):\n            raise ValueError(\n                f\"Wrong position size ({position.size}) of \"\n                f\"{self.joint_dofs(joint_name=joint_name)}-DoFs joint '{joint_name}'.\"\n            )\n\n        idx = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name)\n        offset = self.model.jnt_qposadr[idx]\n\n        sl = np.s_[offset : offset + self.joint_dofs(joint_name=joint_name)]\n        self.data.qpos[sl] = position\n\n    def set_joint_positions(\n        self, joint_names: Sequence[str], positions: npt.NDArray | list[npt.NDArray]\n    ) -> None:\n        \"\"\"Set the positions of multiple joints.\"\"\"\n\n        mask = self.mask_qpos(joint_names=tuple(joint_names))\n        self.data.qpos[mask] = positions\n\n    # ==================\n    # Methods for bodies\n    # ==================\n\n    def number_of_bodies(self) -> int:\n        \"\"\"Return the number of bodies in the model.\"\"\"\n\n        return self.model.nbody\n\n    def body_names(self) -> list[str]:\n        \"\"\"Return 
the names of the bodies in the model.\"\"\"\n\n        return [\n            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, idx)\n            for idx in range(self.number_of_bodies())\n        ]\n\n    def body_position(self, body_name: str) -> npt.NDArray:\n        \"\"\"Return the position of a body.\"\"\"\n\n        if body_name not in self.body_names():\n            raise ValueError(f\"Body '{body_name}' not found\")\n\n        return self.data.body(body_name).xpos\n\n    def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray:\n        \"\"\"Return the orientation of a body.\"\"\"\n\n        if body_name not in self.body_names():\n            raise ValueError(f\"Body '{body_name}' not found\")\n\n        return (\n            self.data.body(body_name).xmat if dcm else self.data.body(body_name).xquat\n        )\n\n    # ======================\n    # Methods for geometries\n    # ======================\n\n    def number_of_geometries(self) -> int:\n        \"\"\"Return the number of geometries in the model.\"\"\"\n\n        return self.model.ngeom\n\n    def geometry_names(self) -> list[str]:\n        \"\"\"Return the names of the geometries in the model.\"\"\"\n\n        return [\n            mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_GEOM, idx)\n            for idx in range(self.number_of_geometries())\n        ]\n\n    def geometry_position(self, geometry_name: str) -> npt.NDArray:\n        \"\"\"Return the position of a geometry.\"\"\"\n\n        if geometry_name not in self.geometry_names():\n            raise ValueError(f\"Geometry '{geometry_name}' not found\")\n\n        return self.data.geom(geometry_name).xpos\n\n    def geometry_orientation(\n        self, geometry_name: str, dcm: bool = False\n    ) -> npt.NDArray:\n        \"\"\"Return the orientation of a geometry.\"\"\"\n\n        if geometry_name not in self.geometry_names():\n            raise ValueError(f\"Geometry '{geometry_name}' not found\")\n\n        R = 
np.reshape(self.data.geom(geometry_name).xmat, newshape=(3, 3))\n\n        if dcm:\n            return R\n\n        q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True, scalar_first=False)\n        return q_xyzw\n\n    # ===============\n    # Private methods\n    # ===============\n\n    def _mask_qpos(self, joint_names: tuple[str, ...]) -> npt.NDArray:\n        \"\"\"\n        Create a mask to access the DoFs of the desired `joint_names` in the `qpos` array.\n\n        Args:\n            joint_names: A tuple containing the names of the joints.\n\n        Returns:\n            A 1D array containing the indices of the `qpos` array to access the DoFs of\n            the desired `joint_names`.\n\n        Note:\n            This method takes a tuple of strings because we cache the output mask for\n            each combination of joint names. We need a hashable object for the cache.\n        \"\"\"\n\n        # Get the indices of the joints in `joint_names`.\n        idxs = [\n            mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, joint_name)\n            for joint_name in joint_names\n        ]\n\n        # We first get the index of each joint in the qpos array, and for those that\n        # have multiple DoFs, we expand their mask by appending new elements.\n        # Finally, we flatten the list of arrays to a single array, that is the\n        # final qpos mask accessing all the DoFs of the desired `joint_names`.\n        return np.atleast_1d(\n            np.hstack(\n                [\n                    np.array(\n                        [\n                            self.model.jnt_qposadr[idx] + i\n                            for i in range(self.joint_dofs(joint_name=joint_name))\n                        ]\n                    )\n                    for idx, joint_name in zip(idxs, joint_names, strict=True)\n                ]\n            ).squeeze()\n        )\n\n\ndef generate_hfield(\n    heightmap: HeightmapCallable,\n    samples_xy: tuple[int, 
int] = (11, 11),\n    radius_xy: tuple[float, float] = (1.0, 1.0),\n) -> npt.NDArray:\n    \"\"\"\n    Generate an array with elevation points sampled from a heightmap function.\n\n    The map will have the following format:\n    ```\n    heightmap[0, 0] heightmap[0, 1] ... heightmap[0, size[1]-1]\n    heightmap[1, 0] heightmap[1, 1] ... heightmap[1, size[1]-1]\n    ...\n    heightmap[size[0]-1, 0] heightmap[size[0]-1, 1] ... heightmap[size[0]-1, size[1]-1]\n    ```\n\n    Args:\n        heightmap:\n            A function that takes two arguments (x, y) and returns the height\n            at that point.\n        samples_xy: A tuple of two integers representing the size of the grid.\n        radius_xy:\n            A tuple of two floats representing extension of the heightmap in the\n            x-y surface corresponding to the area over which the grid of the sampled\n            heightmap is generated.\n\n    Returns:\n        A flat array of the sampled terrain heightmap.\n    \"\"\"\n\n    # Generate the grid.\n    x = np.linspace(-radius_xy[0], radius_xy[0], samples_xy[0])\n    y = np.linspace(-radius_xy[1], radius_xy[1], samples_xy[1])\n\n    # Generate the heightmap.\n    return np.array([[heightmap(xi, yi) for xi in x] for yi in y]).flatten()\n"
  },
  {
    "path": "src/jaxsim/mujoco/utils.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nfrom collections.abc import Sequence\n\nimport mujoco as mj\nimport numpy as np\nimport numpy.typing as npt\nfrom scipy.spatial.transform import Rotation\n\nfrom .model import MujocoModelHelper\n\n\ndef mujoco_data_from_jaxsim(\n    mujoco_model: mj.MjModel,\n    jaxsim_model,\n    jaxsim_data,\n    mujoco_data: mj.MjData | None = None,\n    update_removed_joints: bool = True,\n) -> mj.MjData:\n    \"\"\"\n    Create a Mujoco data object from a JaxSim model and data objects.\n\n    Args:\n        mujoco_model: The Mujoco model object corresponding to the JaxSim model.\n        jaxsim_model: The JaxSim model object from which the Mujoco model was created.\n        jaxsim_data: The JaxSim data object containing the state of the model.\n        mujoco_data: An optional Mujoco data object. If None, a new one will be created.\n        update_removed_joints:\n            If True, the positions of the joints that have been removed during the\n            model reduction process will be set to their initial values.\n\n    Returns:\n        The Mujoco data object containing the state of the JaxSim model.\n\n    Note:\n        This method is useful to initialize a Mujoco data object used for visualization\n        with the state of a JaxSim model. In particular, this function takes care of\n        initializing the positions of the joints that have been removed during the\n        model reduction process. 
After the initial creation of the Mujoco data object,\n        it's faster to update the state using an external MujocoModelHelper object.\n    \"\"\"\n\n    # The package `jaxsim.mujoco` is supposed to be jax-independent.\n    # We import all the JaxSim resources privately.\n    import jaxsim.api as js\n\n    if not isinstance(jaxsim_model, js.model.JaxSimModel):\n        raise ValueError(\"The `jaxsim_model` argument must be a JaxSimModel object.\")\n\n    if not isinstance(jaxsim_data, js.data.JaxSimModelData):\n        raise ValueError(\"The `jaxsim_data` argument must be a JaxSimModelData object.\")\n\n    # Create the helper to operate on the Mujoco model and data.\n    model_helper = MujocoModelHelper(model=mujoco_model, data=mujoco_data)\n\n    # If the model is fixed-base, the Mujoco model won't have the joint corresponding\n    # to the floating base, and the helper would raise an exception.\n    if jaxsim_model.floating_base():\n\n        # Set the model position.\n        model_helper.set_base_position(position=np.array(jaxsim_data.base_position))\n\n        # Set the model orientation.\n        model_helper.set_base_orientation(\n            orientation=np.array(jaxsim_data.base_orientation)\n        )\n\n    # Set the joint positions.\n    if jaxsim_model.dofs() > 0:\n\n        model_helper.set_joint_positions(\n            joint_names=list(jaxsim_model.joint_names()),\n            positions=np.array(jaxsim_data.joint_positions),\n        )\n\n    # Updating these joints is not necessary after the first time.\n    # Users can disable this update after initialization.\n    if update_removed_joints:\n\n        # Create a dictionary with the joints that have been removed for various reasons\n        # (like link lumping due to model reduction).\n        joints_removed_dict = {\n            j.name: j\n            for j in jaxsim_model.description._joints_removed\n            if j.name not in set(jaxsim_model.joint_names())\n        }\n\n        # Set the 
positions of the removed joints.\n        _ = [\n            model_helper.set_joint_position(\n                position=joints_removed_dict[joint_name].initial_position,\n                joint_name=joint_name,\n            )\n            # Select all original joint that have been removed from the JaxSim model\n            # that are still present in the Mujoco model.\n            for joint_name in joints_removed_dict\n            if joint_name in model_helper.joint_names()\n        ]\n\n    # Return the mujoco data with updated kinematics.\n    mj.mj_forward(mujoco_model, model_helper.data)\n\n    return model_helper.data\n\n\n@dataclasses.dataclass\nclass MujocoCamera:\n    \"\"\"\n    Helper class storing parameters of a Mujoco camera.\n\n    Refer to the official documentation for more details:\n    https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-camera\n    \"\"\"\n\n    mode: str = \"fixed\"\n\n    target: str | None = None\n    fovy: str = \"45\"\n    pos: str = \"0 0 0\"\n\n    quat: str | None = None\n    axisangle: str | None = None\n    xyaxes: str | None = None\n    zaxis: str | None = None\n    euler: str | None = None\n\n    name: str | None = None\n\n    @classmethod\n    def build(cls, **kwargs) -> MujocoCamera:\n        \"\"\"\n        Build a Mujoco camera from a dictionary.\n        \"\"\"\n\n        if not all(isinstance(value, str) for value in kwargs.values()):\n            raise ValueError(f\"Values must be strings: {kwargs}\")\n\n        return cls(**kwargs)\n\n    @staticmethod\n    def build_from_target_view(\n        camera_name: str,\n        mode: str = \"fixed\",\n        lookat: Sequence[float | int] | npt.NDArray = (0, 0, 0),\n        distance: float | int | npt.NDArray = 3,\n        azimuth: float | int | npt.NDArray = 90,\n        elevation: float | int | npt.NDArray = -45,\n        fovy: float | int | npt.NDArray = 45,\n        degrees: bool = True,\n        **kwargs,\n    ) -> MujocoCamera:\n        \"\"\"\n        
Create a custom camera that looks at a target point.\n\n        Note:\n            The choice of the parameters is easier if we consider a target\n            frame `T` whose origin is located over the lookat point and having the same\n            orientation as the world frame `W`. We also introduce a camera frame `C`\n            whose origin is located over the lower-left corner of the image, and having\n            the x-axis pointing right and the y-axis pointing up in image coordinates.\n            The camera renders what it sees in the -z direction of frame `C`.\n\n        Args:\n            camera_name: The name of the camera.\n            mode: Camera positioning mode:\n                - **\"fixed\"**: Fixed position and orientation relative to the body.\n                - **\"track\"**: Fixed offset from the body in world coordinates, constant orientation.\n                - **\"trackcom\"**: Like `\"track\"`, but relative to the center of mass of the subtree.\n                - **\"targetbody\"**: Fixed position in body frame, oriented toward a target body.\n                - **\"targetbodycom\"**: Like `\"targetbody\"`, but targets the subtree's center of mass.\n            lookat: The target point to look at (origin of `T`).\n            distance:\n                The distance from the target point (displacement between the origins\n                of `T` and `C`).\n            azimuth:\n                The rotation around z of the camera. With an angle of 0, the camera\n                would look at the target point towards the positive x-axis of `T`.\n            elevation:\n                The rotation around the x-axis of the camera frame `C`. Note that if\n                you want to lift the view angle, the elevation is negative.\n            fovy: The field of view of the camera.\n            degrees: Whether the angles are in degrees or radians.\n            **kwargs: Additional camera parameters.\n\n        Returns:\n            The custom camera.\n        \"\"\"\n\n        # Start from a frame whose origin is located over the lookat point.\n        # We initialize it with a -90 degrees rotation around the z-axis followed\n        # by a 90 degrees rotation around the new x-axis, due to the default camera\n        # coordinate system (x pointing right, y pointing up).\n        W_H_C = np.eye(4)\n        W_H_C[0:3, 3] = np.array(lookat)\n        W_H_C[0:3, 0:3] = Rotation.from_euler(\n            seq=\"ZX\", angles=[-90, 90], degrees=True\n        ).as_matrix()\n\n        # Process the azimuth.\n        R_az = Rotation.from_euler(seq=\"Y\", angles=azimuth, degrees=degrees).as_matrix()\n        W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_az\n\n        # Process elevation.\n        R_el = Rotation.from_euler(\n            seq=\"X\", angles=elevation, degrees=degrees\n        ).as_matrix()\n        W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_el\n\n        # Process distance.\n        tf_distance = np.eye(4)\n        tf_distance[2, 3] = distance\n        W_H_C = W_H_C @ tf_distance\n\n        # Extract the position and the quaternion.\n        p = W_H_C[0:3, 3]\n        Q = Rotation.from_matrix(W_H_C[0:3, 0:3]).as_quat(scalar_first=True)\n\n        return MujocoCamera.build(\n            name=camera_name,\n            mode=mode,\n            fovy=str(fovy if degrees else np.rad2deg(fovy)),\n            pos=\" \".join(p.astype(str).tolist()),\n            quat=\" \".join(Q.astype(str).tolist()),\n            **kwargs,\n        )\n\n    def asdict(self) -> dict[str, str]:\n        \"\"\"\n        Convert the camera to a dictionary.\n        \"\"\"\n        return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}\n"
  },
  {
    "path": "src/jaxsim/mujoco/visualizer.py",
    "content": "import contextlib\nimport pathlib\nfrom collections.abc import Iterator, Sequence\n\nimport mediapy as media\nimport mujoco as mj\nimport mujoco.viewer\nimport numpy as np\nimport numpy.typing as npt\nfrom scipy.spatial.transform import Rotation\n\n\nclass MujocoVideoRecorder:\n    \"\"\"\n    Video recorder for the MuJoCo passive viewer.\n    \"\"\"\n\n    def __init__(\n        self,\n        model: list[mj.MjModel] | mj.MjModel,\n        data: list[mj.MjData] | mj.MjData,\n        fps: int = 30,\n        width: int | None = None,\n        height: int | None = None,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Initialize the Mujoco video recorder.\n\n        Args:\n            model: The Mujoco model.\n            data: The Mujoco data.\n            fps: The frames per second.\n            width: The width of the video.\n            height: The height of the video.\n            **kwargs: Additional arguments for the renderer.\n        \"\"\"\n\n        if isinstance(model, mj.MjModel):\n            single_model = model\n        elif isinstance(model, list) and len(model) == 1:\n            single_model = model[0]\n        else:\n            raise ValueError(\n                \"Model must be a single instance of mj.MjModel or a list with at least one element.\"\n            )\n\n        width = width if width is not None else single_model.vis.global_.offwidth\n        height = height if height is not None else single_model.vis.global_.offheight\n\n        if single_model.vis.global_.offwidth != width:\n            single_model.vis.global_.offwidth = width\n\n        if single_model.vis.global_.offheight != height:\n            single_model.vis.global_.offheight = height\n\n        self.fps = fps\n        self.frames: list[npt.NDArray] = []\n        self.data: list[mj.MjData] | mj.MjData | None = None\n        self.model: list[mj.MjModel] | mj.MjModel | None = None\n        self.reset(model=model, data=data)\n\n        self.renderer = 
mujoco.Renderer(\n            model=single_model,\n            **(dict(width=width, height=height) | kwargs),\n        )\n\n    def visualize_frame(\n        self, frame_pose: list[float] | npt.NDArray | None = None\n    ) -> None:\n        \"\"\"\n        Add visualization of a static frame.\n\n        Args:\n            frame_pose: The pose of a static frame to visualize as [x, y, z, roll, pitch, yaw].\n        \"\"\"\n\n        scene = self.renderer.scene\n\n        # Three free slots are needed for the axes (x, y, z).\n        if scene.ngeom + 3 > scene.maxgeom:\n            return\n\n        # Read position and RPY orientation\n        if not frame_pose:\n            return\n        try:\n            x, y, z, roll, pitch, yaw = frame_pose\n        except Exception as e:\n            raise ValueError(\n                \"Frame pose elements must be a 6D list: 'x y z roll pitch yaw'\"\n            ) from e\n\n        mat = Rotation.from_euler(\"xyz\", [roll, pitch, yaw], degrees=False).as_matrix()\n\n        origin = np.array([x, y, z])\n        length = 0.2  # length of axis cylinders\n        radius = 0.01  # slim radius for cylinders\n\n        for axis, color in zip(\n            range(3), [(1, 0, 0, 1), (0, 1, 0, 1), (0, 0, 1, 1)], strict=True\n        ):\n            if scene.ngeom >= scene.maxgeom:\n                break\n\n            axis_dir = mat[:, axis]\n\n            geom = scene.geoms[scene.ngeom]\n\n            # Cylinder position is centered at origin + half length along axis\n            pos = origin + axis_dir * length * 0.5\n\n            # Build rotation matrix for cylinder aligned with axis_dir\n            # MuJoCo's cylinder local axis is along z-axis\n            def rot_from_z(v: np.ndarray) -> np.ndarray:\n                v = v / np.linalg.norm(v)\n                z_axis = np.array([0, 0, 1])\n                if np.allclose(v, z_axis):\n                    return np.eye(3)\n                if np.allclose(v, -z_axis):\n                   
 return np.diag([1, -1, -1])\n                cross = np.cross(z_axis, v)\n                dot = np.dot(z_axis, v)\n                skew = np.array(\n                    [\n                        [0, -cross[2], cross[1]],\n                        [cross[2], 0, -cross[0]],\n                        [-cross[1], cross[0], 0],\n                    ]\n                )\n                R = np.eye(3) + skew + skew @ skew * (1 / (1 + dot))\n                return R\n\n            R = rot_from_z(axis_dir)\n            mat_flat = R.flatten()\n\n            mj.mjv_initGeom(\n                geom=geom,\n                type=mj.mjtGeom.mjGEOM_CYLINDER,\n                # The `size` arguments takes three positional arguments.\n                # In the cylinder case, the first two are the radius and half-length,\n                # and the third is not used (set to 0.0).\n                size=np.array([radius, length * 0.5, 0.0]),\n                rgba=np.array(color),\n                pos=pos,\n                mat=mat_flat,\n            )\n            geom.category = mj.mjtCatBit.mjCAT_STATIC\n\n            scene.ngeom += 1\n\n    def reset(\n        self,\n        model: mj.MjModel | None = None,\n        data: list[mj.MjData] | mj.MjData | None = None,\n    ) -> None:\n        \"\"\"Reset the model and data.\"\"\"\n\n        self.frames = []\n\n        self.data = data if data is not None else self.data\n        self.data = self.data if isinstance(self.data, list) else [self.data]\n\n        self.model = model if model is not None else self.model\n        self.model = self.model if isinstance(self.model, list) else [self.model]\n\n        assert len(self.data) == len(self.model) or len(self.model) == 1, (\n            f\"Length mismatch: len(data)={len(self.data)}, len(model)={len(self.model)}. 
\"\n            \"They must be equal or model must have length 1.\"\n        )\n\n    def render_frame(\n        self,\n        camera_name: str = \"track\",\n        frame_pose: list[float] | npt.NDArray | None = None,\n    ) -> npt.NDArray:\n        \"\"\"\n        Render a frame.\n\n        Args:\n            camera_name: The name of the camera to use for rendering.\n            frame_pose: The pose of a static frame to visualize as [x, y, z, roll, pitch, yaw].\n\n        Returns:\n            The rendered frame as a NumPy array.\n        \"\"\"\n\n        for idx, data in enumerate(self.data):\n\n            # Use a single model for rendering if multiple data instances are provided.\n            # Otherwise, use the data index to select the corresponding model.\n            model = self.model[0] if len(self.model) == 1 else self.model[idx]\n\n            mj.mj_forward(model, data)\n\n            if idx == 0:\n                self.renderer.update_scene(data=data, camera=camera_name)\n                self.visualize_frame(frame_pose=frame_pose)\n                continue\n\n            mujoco.mjv_addGeoms(\n                m=model,\n                d=data,\n                opt=mj.MjvOption(),\n                pert=mj.MjvPerturb(),\n                catmask=mj.mjtCatBit.mjCAT_DYNAMIC,\n                scn=self.renderer.scene,\n            )\n\n        return self.renderer.render()\n\n    def record_frame(\n        self,\n        camera_name: str = \"track\",\n        frame_pose: list[float] | npt.NDArray | None = None,\n    ) -> None:\n        \"\"\"Store a frame in the buffer.\"\"\"\n\n        frame = self.render_frame(camera_name=camera_name, frame_pose=frame_pose)\n        self.frames.append(frame)\n\n    def write_video(self, path: pathlib.Path | str, exist_ok: bool = False) -> None:\n        \"\"\"Write the video to a file.\"\"\"\n\n        # Resolve the path to the video.\n        path = pathlib.Path(path).expanduser().resolve()\n\n        if path.is_dir():\n   
         raise IsADirectoryError(f\"The path '{path}' is a directory.\")\n\n        if not exist_ok and path.is_file():\n            raise FileExistsError(f\"The file '{path}' already exists.\")\n\n        media.write_video(path=path, images=np.array(self.frames), fps=self.fps)\n\n    @staticmethod\n    def compute_down_sampling(original_fps: int, target_min_fps: int) -> int:\n        \"\"\"\n        Return the integer down-sampling factor to reach at least the target fps.\n\n        Args:\n            original_fps: The original fps.\n            target_min_fps: The target minimum fps.\n\n        Returns:\n            The down-sampling factor.\n        \"\"\"\n\n        down_sampling = 1\n        down_sampling_final = down_sampling\n\n        while original_fps / (down_sampling + 1) >= target_min_fps:\n            down_sampling = down_sampling + 1\n\n            if int(original_fps / down_sampling) == original_fps / down_sampling:\n                down_sampling_final = down_sampling\n\n        return down_sampling_final\n\n\nclass MujocoVisualizer:\n    \"\"\"\n    Visualizer for the MuJoCo passive viewer.\n    \"\"\"\n\n    def __init__(\n        self, model: mj.MjModel | None = None, data: mj.MjData | None = None\n    ) -> None:\n        \"\"\"\n        Initialize the Mujoco visualizer.\n\n        Args:\n            model: The Mujoco model.\n            data: The Mujoco data.\n        \"\"\"\n\n        self.data = data\n        self.model = model\n\n    def sync(\n        self,\n        viewer: mj.viewer.Handle,\n        model: mj.MjModel | None = None,\n        data: mj.MjData | None = None,\n    ) -> None:\n        \"\"\"Update the viewer with the current model and data.\"\"\"\n\n        data = data if data is not None else self.data\n        model = model if model is not None else self.model\n\n        mj.mj_forward(model, data)\n        viewer.sync()\n\n    def open_viewer(\n        self,\n        model: mj.MjModel | None = None,\n        data: mj.MjData | 
None = None,\n        show_left_ui: bool = False,\n    ) -> mj.viewer.Handle:\n        \"\"\"Open a viewer.\"\"\"\n\n        data = data if data is not None else self.data\n        model = model if model is not None else self.model\n\n        handle = mj.viewer.launch_passive(\n            model, data, show_left_ui=show_left_ui, show_right_ui=False\n        )\n\n        return handle\n\n    @contextlib.contextmanager\n    def open(\n        self,\n        model: mj.MjModel | None = None,\n        data: mj.MjData | None = None,\n        *,\n        show_left_ui: bool = False,\n        close_on_exit: bool = True,\n        lookat: Sequence[float | int] | npt.NDArray | None = None,\n        distance: float | int | npt.NDArray | None = None,\n        azimuth: float | int | npt.NDArray | None = None,\n        elevation: float | int | npt.NDArray | None = None,\n    ) -> Iterator[mj.viewer.Handle]:\n        \"\"\"\n        Context manager to open the Mujoco passive viewer.\n\n        Note:\n            Refer to the Mujoco documentation for details of the camera options:\n            https://mujoco.readthedocs.io/en/stable/XMLreference.html#visual-global\n        \"\"\"\n\n        handle = self.open_viewer(model=model, data=data, show_left_ui=show_left_ui)\n\n        handle = MujocoVisualizer.setup_viewer_camera(\n            viewer=handle,\n            lookat=lookat,\n            distance=distance,\n            azimuth=azimuth,\n            elevation=elevation,\n        )\n\n        try:\n            yield handle\n        finally:\n            _ = handle.close() if close_on_exit else None\n\n    @staticmethod\n    def setup_viewer_camera(\n        viewer: mj.viewer.Handle,\n        *,\n        lookat: Sequence[float | int] | npt.NDArray | None,\n        distance: float | int | npt.NDArray | None = None,\n        azimuth: float | int | npt.NDArray | None = None,\n        elevation: float | int | npt.NDArray | None = None,\n    ) -> mj.viewer.Handle:\n        \"\"\"\n       
 Configure the initial viewpoint of the Mujoco passive viewer.\n\n        Note:\n            Refer to the Mujoco documentation for details of the camera options:\n            https://mujoco.readthedocs.io/en/stable/XMLreference.html#visual-global\n\n        Returns:\n            The viewer with configured camera.\n        \"\"\"\n\n        if lookat is not None:\n\n            lookat_array = np.array(lookat, dtype=float).squeeze()\n\n            if lookat_array.size != 3:\n                raise ValueError(lookat)\n\n            viewer.cam.lookat = lookat_array\n\n        if distance is not None:\n            viewer.cam.distance = float(distance)\n\n        if azimuth is not None:\n            viewer.cam.azimuth = float(azimuth) % 360\n\n        if elevation is not None:\n            viewer.cam.elevation = float(elevation)\n\n        return viewer\n"
  },
  {
    "path": "src/jaxsim/parsers/__init__.py",
    "content": ""
  },
  {
    "path": "src/jaxsim/parsers/descriptions/__init__.py",
    "content": "from .collision import (\n    BoxCollision,\n    CollidablePoint,\n    CollisionShape,\n    MeshCollision,\n    SphereCollision,\n)\nfrom .joint import JointDescription, JointGenericAxis, JointType\nfrom .link import LinkDescription\nfrom .model import ModelDescription\n"
  },
  {
    "path": "src/jaxsim/parsers/descriptions/collision.py",
    "content": "from __future__ import annotations\n\nimport abc\nimport dataclasses\n\nimport jax.numpy as jnp\nimport numpy as np\nimport numpy.typing as npt\n\nimport jaxsim.typing as jtp\nfrom jaxsim import logging\n\nfrom .link import LinkDescription\n\n\n@dataclasses.dataclass\nclass CollidablePoint:\n    \"\"\"\n    Represents a collidable point associated with a parent link.\n\n    Attributes:\n        parent_link: The parent link to which the collidable point is attached.\n        position: The position of the collidable point relative to the parent link.\n        enabled: A flag indicating whether the collidable point is enabled for collision detection.\n    \"\"\"\n\n    parent_link: LinkDescription\n    position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3))\n    enabled: bool = True\n\n    def change_link(\n        self, new_link: LinkDescription, new_H_old: npt.NDArray\n    ) -> CollidablePoint:\n        \"\"\"\n        Move the collidable point to a new parent link.\n\n        Args:\n            new_link (LinkDescription): The new parent link to which the collidable point is moved.\n            new_H_old (npt.NDArray): The transformation matrix from the new link's frame to the old link's frame.\n\n        Returns:\n            CollidablePoint: A new collidable point associated with the new parent link.\n        \"\"\"\n\n        msg = f\"Moving collidable point: {self.parent_link.name} -> {new_link.name}\"\n        logging.debug(msg=msg)\n\n        return CollidablePoint(\n            parent_link=new_link,\n            position=(new_H_old @ jnp.hstack([self.position, 1.0])).squeeze()[0:3],\n            enabled=self.enabled,\n        )\n\n    def __hash__(self) -> int:\n\n        return hash(\n            (\n                hash(self.parent_link),\n                hash(tuple(self.position.tolist())),\n                hash(self.enabled),\n            )\n        )\n\n    def __eq__(self, other: CollidablePoint) -> bool:\n\n        
if not isinstance(other, CollidablePoint):\n            return False\n\n        return hash(self) == hash(other)\n\n    def __str__(self) -> str:\n        return (\n            f\"{self.__class__.__name__}(\"\n            + f\"parent_link={self.parent_link.name}\"\n            + f\", position={self.position}\"\n            + f\", enabled={self.enabled}\"\n            + \")\"\n        )\n\n\n@dataclasses.dataclass\nclass CollisionShape(abc.ABC):\n    \"\"\"\n    Abstract base class for representing collision shapes.\n\n    Attributes:\n        collidable_points: A list of collidable points associated with the collision shape.\n    \"\"\"\n\n    collidable_points: tuple[CollidablePoint]\n\n    def __str__(self):\n        return (\n            f\"{self.__class__.__name__}(\"\n            + \"collidable_points=[\\n    \"\n            + \",\\n    \".join(str(cp) for cp in self.collidable_points)\n            + \"\\n])\"\n        )\n\n\n@dataclasses.dataclass\nclass BoxCollision(CollisionShape):\n    \"\"\"\n    Represents a box-shaped collision shape.\n\n    Attributes:\n        center: The center of the box in the local frame of the collision shape.\n    \"\"\"\n\n    center: jtp.VectorLike\n\n    def __hash__(self) -> int:\n        return hash(\n            (\n                hash(super()),\n                hash(tuple(self.center.tolist())),\n            )\n        )\n\n    def __eq__(self, other: BoxCollision) -> bool:\n\n        if not isinstance(other, BoxCollision):\n            return False\n\n        return hash(self) == hash(other)\n\n\n@dataclasses.dataclass\nclass SphereCollision(CollisionShape):\n    \"\"\"\n    Represents a spherical collision shape.\n\n    Attributes:\n        center: The center of the sphere in the local frame of the collision shape.\n    \"\"\"\n\n    center: jtp.VectorLike\n\n    def __hash__(self) -> int:\n        return hash(\n            (\n                hash(super()),\n                hash(tuple(self.center.tolist())),\n           
 )\n        )\n\n    def __eq__(self, other: BoxCollision) -> bool:\n\n        if not isinstance(other, BoxCollision):\n            return False\n\n        return hash(self) == hash(other)\n\n\n@dataclasses.dataclass\nclass MeshCollision(CollisionShape):\n    \"\"\"\n    Represents a mesh-shaped collision shape.\n\n    Attributes:\n        center: The center of the mesh in the local frame of the collision shape.\n    \"\"\"\n\n    center: jtp.VectorLike\n\n    def __hash__(self) -> int:\n        return hash(\n            (\n                hash(tuple(self.center.tolist())),\n                hash(self.collidable_points),\n            )\n        )\n\n    def __eq__(self, other: MeshCollision) -> bool:\n        if not isinstance(other, MeshCollision):\n            return False\n\n        return hash(self) == hash(other)\n"
  },
  {
    "path": "src/jaxsim/parsers/descriptions/joint.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nfrom typing import ClassVar\n\nimport jax_dataclasses\nimport numpy as np\n\nimport jaxsim.typing as jtp\nfrom jaxsim.utils import JaxsimDataclass, Mutability\n\nfrom .link import LinkDescription\n\n\n@dataclasses.dataclass(frozen=True)\nclass JointType:\n    \"\"\"\n    Enumeration of joint types.\n    \"\"\"\n\n    Fixed: ClassVar[int] = 0\n    Revolute: ClassVar[int] = 1\n    Prismatic: ClassVar[int] = 2\n\n\n@jax_dataclasses.pytree_dataclass\nclass JointGenericAxis:\n    \"\"\"\n    A joint requiring the specification of a 3D axis.\n    \"\"\"\n\n    # The axis of rotation or translation of the joint (must have norm 1).\n    axis: jtp.Vector\n\n    def __hash__(self) -> int:\n\n        return hash(tuple(self.axis.tolist()))\n\n    def __eq__(self, other: JointGenericAxis) -> bool:\n\n        if not isinstance(other, JointGenericAxis):\n            return False\n\n        return hash(self) == hash(other)\n\n\n@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)\nclass JointDescription(JaxsimDataclass):\n    \"\"\"\n    In-memory description of a robot joint.\n\n    Attributes:\n        name: The name of the joint.\n        axis: The axis of rotation or translation for the joint.\n        pose: The pose transformation matrix of the joint.\n        jtype: The type of the joint.\n        child: The child link attached to the joint.\n        parent: The parent link attached to the joint.\n        index: An optional index for the joint.\n        friction_static: The static friction coefficient for the joint.\n        friction_viscous: The viscous friction coefficient for the joint.\n        position_limit_damper: The damper coefficient for position limits.\n        position_limit_spring: The spring coefficient for position limits.\n        position_limit: The position limits for the joint.\n        initial_position: The initial position of the joint.\n    \"\"\"\n\n    name: jax_dataclasses.Static[str]\n    axis: jtp.Vector\n    pose: jtp.Matrix\n    jtype: jax_dataclasses.Static[jtp.IntLike]\n    child: LinkDescription = dataclasses.dataclass(repr=False)\n    parent: LinkDescription = dataclasses.dataclass(repr=False)\n\n    index: jtp.IntLike | None = None\n\n    friction_static: jtp.FloatLike = 0.0\n    friction_viscous: jtp.FloatLike = 0.0\n\n    position_limit_damper: jtp.FloatLike = 0.0\n    position_limit_spring: jtp.FloatLike = 0.0\n\n    position_limit: tuple[jtp.FloatLike, jtp.FloatLike] = (0.0, 0.0)\n    initial_position: jtp.FloatLike | jtp.VectorLike = 0.0\n\n    motor_inertia: jtp.FloatLike = 0.0\n    motor_viscous_friction: jtp.FloatLike = 0.0\n    motor_gear_ratio: jtp.FloatLike = 1.0\n\n    def __post_init__(self) -> None:\n\n        if self.axis is not None:\n\n            with self.mutable_context(\n                mutability=Mutability.MUTABLE, restore_after_exception=False\n            ):\n                norm_of_axis = np.linalg.norm(self.axis)\n                self.axis = self.axis / norm_of_axis\n\n    def __eq__(self, other: JointDescription) -> bool:\n\n        if not isinstance(other, JointDescription):\n            return False\n\n        return hash(self) == hash(other)\n\n    def __hash__(self) -> int:\n\n        from jaxsim.utils.wrappers import HashedNumpyArray\n\n        return hash(\n            (\n                hash(self.name),\n                HashedNumpyArray.hash_of_array(self.axis),\n                HashedNumpyArray.hash_of_array(self.pose),\n                hash(int(self.jtype)),\n                hash(self.child),\n                hash(self.parent),\n                hash(int(self.index)) if self.index is not None else 0,\n                HashedNumpyArray.hash_of_array(self.friction_static),\n                HashedNumpyArray.hash_of_array(self.friction_viscous),\n                HashedNumpyArray.hash_of_array(self.position_limit_damper),\n                HashedNumpyArray.hash_of_array(self.position_limit_spring),\n                HashedNumpyArray.hash_of_array(self.position_limit),\n                HashedNumpyArray.hash_of_array(self.initial_position),\n                HashedNumpyArray.hash_of_array(self.motor_inertia),\n                HashedNumpyArray.hash_of_array(self.motor_viscous_friction),\n                HashedNumpyArray.hash_of_array(self.motor_gear_ratio),\n            ),\n        )\n"
  },
  {
    "path": "src/jaxsim/parsers/descriptions/link.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\n\nimport jax.numpy as jnp\nimport jax_dataclasses\nimport numpy as np\nfrom jax_dataclasses import Static\n\nimport jaxsim.typing as jtp\nfrom jaxsim.math import Adjoint\nfrom jaxsim.utils import JaxsimDataclass\n\n\n@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)\nclass LinkDescription(JaxsimDataclass):\n    \"\"\"\n    In-memory description of a robot link.\n\n    Attributes:\n        name: The name of the link.\n        mass: The mass of the link.\n        inertia: The inertia tensor of the link.\n        index: An optional index for the link (it gets automatically assigned).\n        parent_name: The name of the parent link, if any.\n        pose: The pose transformation matrix of the link.\n        children: The child links.\n    \"\"\"\n\n    name: Static[str]\n    mass: float = dataclasses.field(repr=False)\n    inertia: jtp.Matrix = dataclasses.field(repr=False)\n    index: int | None = None\n    parent_name: Static[str | None] = dataclasses.field(default=None, repr=False)\n    pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False)\n\n    children: Static[tuple[LinkDescription]] = dataclasses.field(\n        default_factory=list, repr=False\n    )\n\n    def __hash__(self) -> int:\n\n        from jaxsim.utils.wrappers import HashedNumpyArray\n\n        return hash(\n            (\n                hash(self.name),\n                hash(float(self.mass)),\n                HashedNumpyArray.hash_of_array(self.inertia),\n                hash(int(self.index)) if self.index is not None else 0,\n                HashedNumpyArray.hash_of_array(self.pose),\n                hash(tuple(self.children)),\n                hash(self.parent_name) if self.parent_name is not None else 0,\n            )\n        )\n\n    def __eq__(self, other: LinkDescription) -> bool:\n\n        if not isinstance(other, LinkDescription):\n            return False\n\n        if not (\n            self.name == other.name\n            and np.allclose(self.mass, other.mass)\n            and np.allclose(self.inertia, other.inertia)\n            and self.index == other.index\n            and np.allclose(self.pose, other.pose)\n            and self.children == other.children\n            and self.parent_name == other.parent_name\n        ):\n            return False\n\n        return True\n\n    @property\n    def name_and_index(self) -> str:\n        \"\"\"\n        Get a formatted string with the link's name and index.\n\n        Returns:\n            str: The formatted string.\n\n        \"\"\"\n        return f\"#{self.index}_<{self.name}>\"\n\n    def lump_with(\n        self, link: LinkDescription, lumped_H_removed: jtp.Matrix\n    ) -> LinkDescription:\n        \"\"\"\n        Combine the current link with another link, preserving mass and inertia.\n\n        Args:\n            link: The link to combine with.\n            lumped_H_removed: The transformation matrix between the two links.\n\n        Returns:\n            The combined link.\n        \"\"\"\n\n        # Get the 6D inertia of the link to remove.\n        I_removed = link.inertia\n\n        # Create the adjoint matrix of the transform. Note the inverse.\n        r_X_l = Adjoint.from_transform(transform=lumped_H_removed, inverse=True)\n\n        # Move the inertia.\n        I_removed_in_lumped_frame = r_X_l.transpose() @ I_removed @ r_X_l\n\n        # Create the new combined link.\n        lumped_link = self.replace(\n            mass=self.mass + link.mass,\n            inertia=self.inertia + I_removed_in_lumped_frame,\n        )\n\n        return lumped_link\n"
  },
  {
    "path": "src/jaxsim/parsers/descriptions/model.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nimport itertools\nfrom collections.abc import Sequence\n\nfrom jaxsim import logging\nfrom jaxsim.logging import jaxsim_warn\n\nfrom ..kinematic_graph import KinematicGraph, KinematicGraphTransforms, RootPose\nfrom .collision import CollidablePoint, CollisionShape\nfrom .joint import JointDescription\nfrom .link import LinkDescription\n\n\n@dataclasses.dataclass(frozen=True, eq=False, unsafe_hash=False)\nclass ModelDescription(KinematicGraph):\n    \"\"\"\n    Intermediate representation representing the kinematic graph of a robot model.\n\n    Attributes:\n        name: The name of the model.\n        fixed_base: Whether the model is either fixed-base or floating-base.\n        collision_shapes: List of collision shapes associated with the model.\n    \"\"\"\n\n    name: str = None\n\n    fixed_base: bool = True\n\n    collision_shapes: tuple[CollisionShape, ...] = dataclasses.field(\n        default_factory=list, repr=False\n    )\n\n    @staticmethod\n    def build_model_from(\n        name: str,\n        links: list[LinkDescription],\n        joints: list[JointDescription],\n        frames: list[LinkDescription] | None = None,\n        collisions: tuple[CollisionShape, ...] = (),\n        fixed_base: bool = False,\n        base_link_name: str | None = None,\n        considered_joints: Sequence[str] | None = None,\n        model_pose: RootPose = RootPose(),\n    ) -> ModelDescription:\n        \"\"\"\n        Build a model description from provided components.\n\n        Args:\n            name: The name of the model.\n            links: List of link descriptions.\n            joints: List of joint descriptions.\n            frames: List of frame descriptions.\n            collisions: List of collision shapes associated with the model.\n            fixed_base: Indicates whether the model has a fixed base.\n            base_link_name: Name of the base link (i.e. 
the root of the kinematic tree).\n            considered_joints: List of joint names to consider (by default all joints).\n            model_pose: Pose of the model's root (by default an identity transform).\n\n        Returns:\n            A ModelDescription instance representing the model.\n        \"\"\"\n\n        # Create the full kinematic graph.\n        kinematic_graph = KinematicGraph.build_from(\n            links=links,\n            joints=joints,\n            frames=frames,\n            root_link_name=base_link_name,\n            root_pose=model_pose,\n        )\n\n        # Reduce the graph if needed.\n        if considered_joints is not None:\n            kinematic_graph = kinematic_graph.reduce(\n                considered_joints=considered_joints\n            )\n\n        # Create the object to compute forward kinematics.\n        fk = KinematicGraphTransforms(graph=kinematic_graph)\n\n        # Container of the final model's collision shapes.\n        final_collisions: list[CollisionShape] = []\n\n        # Move and express the collision shapes of removed links to the resulting\n        # lumped link that replaces the combination of the removed link and its parent.\n        for collision_shape in collisions:\n\n            # Assume they have a unique parent link\n            if (\n                len({cp.parent_link.name for cp in collision_shape.collidable_points})\n                != 1\n            ):\n                msg = \"Collision shape not currently supported (multiple parent links)\"\n                raise RuntimeError(msg)\n\n            # Get the parent link of the collision shape.\n            # Note that this link could have been lumped and we need to find the\n            # link it was lumped into.\n            parent_link_of_shape = collision_shape.collidable_points[0].parent_link\n\n            # If it is part of the (reduced) graph, add it as it is...\n            if parent_link_of_shape.name in 
kinematic_graph.link_names():\n                final_collisions.append(collision_shape)\n                continue\n\n            # ... otherwise look for the frame\n            if parent_link_of_shape.name not in kinematic_graph.frame_names():\n                msg = \"Parent frame '{}' of collision shape not found, ignoring shape\"\n                logging.info(msg.format(parent_link_of_shape.name))\n                continue\n\n            # Create a new collision shape\n            new_collision_shape = CollisionShape(collidable_points=())\n            final_collisions.append(new_collision_shape)\n\n            # If the frame was found, update the collidable points' pose and add them\n            # to the new collision shape.\n            for cp in collision_shape.collidable_points:\n                # Find the link that is part of the (reduced) model into which the\n                # collision shape's parent was lumped\n                real_parent_link_name = kinematic_graph.frames_dict[\n                    parent_link_of_shape.name\n                ].parent_name\n\n                # Change the link associated with the collidable point, updating its\n                # relative pose\n                moved_cp = cp.change_link(\n                    new_link=kinematic_graph.links_dict[real_parent_link_name],\n                    new_H_old=fk.relative_transform(\n                        relative_to=real_parent_link_name,\n                        name=cp.parent_link.name,\n                    ),\n                )\n\n                # Store the updated collidable point.\n                new_collision_shape.collidable_points += (moved_cp,)\n\n        # Build the model\n        model = ModelDescription(\n            name=name,\n            root_pose=kinematic_graph.root_pose,\n            fixed_base=fixed_base,\n            collision_shapes=tuple(final_collisions),\n            root=kinematic_graph.root,\n            joints=kinematic_graph.joints,\n            
frames=kinematic_graph.frames,\n            _joints_removed=kinematic_graph.joints_removed,\n        )\n\n        # Check that the root link of the kinematic graph is the desired base link.\n        assert kinematic_graph.root.name == base_link_name, kinematic_graph.root.name\n\n        return model\n\n    def reduce(self, considered_joints: Sequence[str]) -> ModelDescription:\n        \"\"\"\n        Reduce the model by removing all the joints that are not considered.\n\n        Args:\n            considered_joints: Sequence of names of the joints to keep.\n\n        Returns:\n            A `ModelDescription` instance that only includes the considered joints.\n\n        Raises:\n            ValueError: If any of the considered joints is not part of the model.\n        \"\"\"\n\n        jaxsim_warn(\n            \"The joint order in the model description is not preserved when reducing \"\n            \"the model. Consider using the `names_to_indices` method to get the correct \"\n            \"order of the joints, or use the `joint_names()` method to inspect the internal joint ordering.\"\n        )\n\n        if set(considered_joints) - set(self.joint_names()):\n            extra_joints = set(considered_joints) - set(self.joint_names())\n            msg = f\"Found joints not part of the model: {extra_joints}\"\n            raise ValueError(msg)\n\n        reduced_model_description = ModelDescription.build_model_from(\n            name=self.name,\n            links=list(self.links_dict.values()),\n            joints=self.joints,\n            frames=self.frames,\n            collisions=self.collision_shapes,\n            fixed_base=self.fixed_base,\n            base_link_name=next(iter(self)).name,\n            model_pose=self.root_pose,\n            considered_joints=considered_joints,\n        )\n\n        # Include the unconnected/removed joints from the original model.\n        for joint in self.joints_removed:\n            reduced_model_description.joints_removed.append(joint)\n\n        return reduced_model_description\n\n    def update_collision_shape_of_link(self, link_name: str, 
enabled: bool) -> None:\n        \"\"\"\n        Enable or disable collision shapes associated with a link.\n\n        Args:\n            link_name: The name of the link.\n            enabled: Enable or disable collision shapes associated with the link.\n        \"\"\"\n\n        if link_name not in self.link_names():\n            raise ValueError(link_name)\n\n        for point in self.collision_shape_of_link(\n            link_name=link_name\n        ).collidable_points:\n            point.enabled = enabled\n\n    def collision_shape_of_link(self, link_name: str) -> CollisionShape:\n        \"\"\"\n        Get the collision shape associated with a specific link.\n\n        Args:\n            link_name: The name of the link.\n\n        Returns:\n            The collision shape associated with the link.\n        \"\"\"\n\n        if link_name not in self.link_names():\n            raise ValueError(link_name)\n\n        return CollisionShape(\n            collidable_points=[\n                point\n                for shape in self.collision_shapes\n                for point in shape.collidable_points\n                if point.parent_link.name == link_name\n            ]\n        )\n\n    def all_enabled_collidable_points(self) -> list[CollidablePoint]:\n        \"\"\"\n        Get all enabled collidable points in the model.\n\n        Returns:\n            The list of all enabled collidable points.\n\n        \"\"\"\n\n        # Get iterator of all collidable points\n        all_collidable_points = itertools.chain.from_iterable(\n            [shape.collidable_points for shape in self.collision_shapes]\n        )\n\n        # Return enabled collidable points\n        return [cp for cp in all_collidable_points if cp.enabled]\n\n    def __eq__(self, other: ModelDescription) -> bool:\n\n        if not isinstance(other, ModelDescription):\n            return False\n\n        if not (\n            self.name == other.name\n            and self.fixed_base == 
other.fixed_base\n            and self.root == other.root\n            and self.joints == other.joints\n            and self.frames == other.frames\n            and self.root_pose == other.root_pose\n        ):\n            return False\n\n        return True\n\n    def __hash__(self) -> int:\n\n        return hash(\n            (\n                hash(self.name),\n                hash(self.fixed_base),\n                hash(self.root),\n                hash(tuple(self.joints)),\n                hash(tuple(self.frames)),\n                hash(self.root_pose),\n            )\n        )\n"
  },
  {
    "path": "src/jaxsim/parsers/kinematic_graph.py",
    "content": "from __future__ import annotations\n\nimport copy\nimport dataclasses\nimport functools\nfrom collections.abc import Callable, Iterable, Iterator, Sequence\nfrom typing import Any\n\nimport numpy as np\nimport numpy.typing as npt\n\nimport jaxsim.utils\nfrom jaxsim import logging\nfrom jaxsim.utils import Mutability\n\nfrom .descriptions.joint import JointDescription, JointType\nfrom .descriptions.link import LinkDescription\n\n\n@dataclasses.dataclass\nclass RootPose:\n    \"\"\"\n    Represents the root pose in a kinematic graph.\n\n    Attributes:\n        root_position: The 3D position of the root link of the graph.\n        root_quaternion:\n            The quaternion representing the rotation of the root link of the graph.\n\n    Note:\n        The root link of the kinematic graph is the base link.\n    \"\"\"\n\n    root_position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3))\n\n    root_quaternion: npt.NDArray = dataclasses.field(\n        default_factory=lambda: np.array([1.0, 0, 0, 0])\n    )\n\n    def __hash__(self) -> int:\n\n        from jaxsim.utils.wrappers import HashedNumpyArray\n\n        return hash(\n            (\n                HashedNumpyArray.hash_of_array(self.root_position),\n                HashedNumpyArray.hash_of_array(self.root_quaternion),\n            )\n        )\n\n    def __eq__(self, other: RootPose) -> bool:\n\n        if not isinstance(other, RootPose):\n            return False\n\n        if not np.allclose(self.root_position, other.root_position):\n            return False\n\n        if not np.allclose(self.root_quaternion, other.root_quaternion):\n            return False\n\n        return True\n\n\n@dataclasses.dataclass(frozen=True)\nclass KinematicGraph(Sequence[LinkDescription]):\n    \"\"\"\n    Class storing a kinematic graph having links as nodes and joints as edges.\n\n    Attributes:\n        root: The root node of the kinematic graph.\n        frames: List of frames rigidly 
attached to the graph nodes.\n        joints: List of joints connecting the graph nodes.\n        root_pose: The pose of the kinematic graph's root.\n    \"\"\"\n\n    root: LinkDescription\n    frames: list[LinkDescription] = dataclasses.field(\n        default_factory=list, hash=False, compare=False\n    )\n    joints: list[JointDescription] = dataclasses.field(\n        default_factory=list, hash=False, compare=False\n    )\n\n    root_pose: RootPose = dataclasses.field(default_factory=RootPose)\n\n    # Private attribute storing optional additional info.\n    _extra_info: dict[str, Any] = dataclasses.field(\n        default_factory=dict, repr=False, hash=False, compare=False\n    )\n\n    # Private attribute storing the unconnected joints from the parsed model and\n    # the joints removed after model reduction.\n    _joints_removed: list[JointDescription] = dataclasses.field(\n        default_factory=list, repr=False, hash=False, compare=False\n    )\n\n    @functools.cached_property\n    def links_dict(self) -> dict[str, LinkDescription]:\n        \"\"\"\n        Get a dictionary of links indexed by their name.\n        \"\"\"\n        return {l.name: l for l in iter(self)}\n\n    @functools.cached_property\n    def frames_dict(self) -> dict[str, LinkDescription]:\n        \"\"\"\n        Get a dictionary of frames indexed by their name.\n        \"\"\"\n        return {f.name: f for f in self.frames}\n\n    @functools.cached_property\n    def joints_dict(self) -> dict[str, JointDescription]:\n        \"\"\"\n        Get a dictionary of joints indexed by their name.\n        \"\"\"\n        return {j.name: j for j in self.joints}\n\n    @functools.cached_property\n    def joints_connection_dict(\n        self,\n    ) -> dict[tuple[str, str], JointDescription]:\n        \"\"\"\n        Get a dictionary of joints indexed by the tuple (parent, child) link names.\n        \"\"\"\n        return {(j.parent.name, j.child.name): j for j in self.joints}\n\n    def 
__post_init__(self) -> None:\n\n        # Assign the link index by traversing the graph with BFS.\n        # Here we assume the model being fixed-base, therefore the base link will\n        # have index 0. We will deal with the floating base in a later stage.\n        for index, link in enumerate(self):\n            link.mutable(validate=False).index = index\n\n        # Get the names of the links, frames, and joints.\n        link_names = [l.name for l in self]\n        frame_names = [f.name for f in self.frames]\n        joint_names = [j.name for j in self.joints]\n\n        # Make sure that they are unique.\n        assert len(link_names) == len(set(link_names))\n        assert len(frame_names) == len(set(frame_names))\n        assert len(joint_names) == len(set(joint_names))\n        assert set(link_names).isdisjoint(set(frame_names))\n        assert set(link_names).isdisjoint(set(joint_names))\n\n        # Order frames with their name.\n        super().__setattr__(\"frames\", sorted(self.frames, key=lambda f: f.name))\n\n        # Assign the frame index following the name-based indexing.\n        # We assume the model being fixed-base, therefore the first frame will\n        # have last_link_idx + 1.\n        for index, frame in enumerate(self.frames):\n            with frame.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):\n                frame.index = index + len(self.link_names())\n\n        # Number joints so that their index matches their child link index.\n        # Therefore, the first joint has index 1.\n        links_dict = {l.name: l for l in iter(self)}\n        for joint in self.joints:\n            with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):\n                joint.index = links_dict[joint.child.name].index\n\n        # Check that joint indices are unique.\n        assert len([j.index for j in self.joints]) == len(\n            {j.index for j in self.joints}\n        )\n\n        # Order joints with their 
indices.\n        super().__setattr__(\"joints\", sorted(self.joints, key=lambda j: j.index))\n\n    @staticmethod\n    def build_from(\n        links: list[LinkDescription],\n        joints: list[JointDescription],\n        frames: list[LinkDescription] | None = None,\n        root_link_name: str | None = None,\n        root_pose: RootPose = RootPose(),\n    ) -> KinematicGraph:\n        \"\"\"\n        Build a KinematicGraph from links, joints, and frames.\n\n        Args:\n            links: A list of link descriptions.\n            joints: A list of joint descriptions.\n            frames: A list of frame descriptions.\n            root_link_name:\n                The name of the root link. If not provided, it's assumed to be the\n                first link's name.\n            root_pose: The root pose of the kinematic graph.\n\n        Returns:\n            The resulting kinematic graph.\n        \"\"\"\n\n        # Consider the first link as the root link if not provided.\n        if root_link_name is None:\n            root_link_name = links[0].name\n            logging.debug(msg=f\"Assuming '{root_link_name}' as the root link\")\n\n        # Couple links and joints and create the graph of links.\n        # Note that the pose of the frames is not updated; it is the caller's\n        # responsibility to update their pose if they want to use them.\n        (\n            graph_root_node,\n            graph_joints,\n            graph_frames,\n            unconnected_links,\n            unconnected_joints,\n            unconnected_frames,\n        ) = KinematicGraph._create_graph(\n            links=links, joints=joints, root_link_name=root_link_name, frames=frames\n        )\n\n        for link in unconnected_links:\n            logging.warning(msg=f\"Ignoring unconnected link: '{link.name}'\")\n\n        for joint in unconnected_joints:\n            logging.warning(msg=f\"Ignoring unconnected joint: '{joint.name}'\")\n\n        for frame in 
unconnected_frames:\n            logging.warning(msg=f\"Ignoring unconnected frame: '{frame.name}'\")\n\n        return KinematicGraph(\n            root=graph_root_node,\n            joints=graph_joints,\n            frames=graph_frames,\n            root_pose=root_pose,\n            _joints_removed=unconnected_joints,\n        )\n\n    @staticmethod\n    def _create_graph(\n        links: list[LinkDescription],\n        joints: list[JointDescription],\n        root_link_name: str,\n        frames: list[LinkDescription] | None = None,\n    ) -> tuple[\n        LinkDescription,\n        list[JointDescription],\n        list[LinkDescription],\n        list[LinkDescription],\n        list[JointDescription],\n        list[LinkDescription],\n    ]:\n        \"\"\"\n        Low-level creator of kinematic graph components.\n\n        Args:\n            links: A list of parsed link descriptions.\n            joints: A list of parsed joint descriptions.\n            root_link_name: The name of the root link used as root node of the graph.\n            frames: A list of parsed frame descriptions.\n\n        Returns:\n            A tuple containing the root node of the graph (defining the entire kinematic\n            tree by iterating on its child nodes), the list of joints representing the\n            actual graph edges, the list of frames rigidly attached to the graph nodes,\n            the list of unconnected links, the list of unconnected joints, and the list\n            of unconnected frames.\n        \"\"\"\n\n        # Create a dictionary that maps the link name to the link, for easy retrieval.\n        links_dict: dict[str, LinkDescription] = {\n            l.name: l.mutable(validate=False) for l in links\n        }\n\n        # Create an empty list of frames if not provided.\n        frames = frames if frames is not None else []\n\n        # Create a dictionary that maps the frame name to the frame, for easy retrieval.\n        frames_dict = {frame.name: frame 
for frame in frames}\n\n        # Check that our parser correctly resolved the frame's parent to be a link.\n        for frame in frames:\n            assert frame.parent_name != \"\", frame\n            assert frame.parent_name is not None, frame\n            assert frame.parent_name != \"__model__\", frame\n            assert frame.parent_name not in frames_dict, frame\n\n        # ===========================================================\n        # Populate the kinematic graph with links, joints, and frames\n        # ===========================================================\n\n        # Check the existence of the root link.\n        if root_link_name not in links_dict:\n            raise ValueError(root_link_name)\n\n        # Reset the children of all links; they are rebuilt from the joints below.\n        for link in links_dict.values():\n            link.children = tuple()\n\n        # Couple links and joints to create the kinematic graph.\n        for joint in joints:\n\n            # Get the parent and child links of the joint.\n            parent_link = links_dict[joint.parent.name]\n            child_link = links_dict[joint.child.name]\n\n            assert child_link.name == joint.child.name\n            assert parent_link.name == joint.parent.name\n\n            # Assign the parent of the child link.\n            child_link.parent_name = parent_link.name\n\n            # Add the child link to the parent's children, avoiding duplicates.\n            if child_link.name not in {l.name for l in parent_link.children}:\n                with parent_link.mutable_context(Mutability.MUTABLE_NO_VALIDATION):\n                    parent_link.children = (*parent_link.children, child_link)\n\n        # Collect all the links of the kinematic graph.\n        all_links_in_graph = list(\n            KinematicGraph.breadth_first_search(root=links_dict[root_link_name])\n        )\n\n        # Get the names of all links in the kinematic graph.\n        all_link_names_in_graph = [l.name for l in all_links_in_graph]\n\n        # 
Collect all the joints of the kinematic graph.\n        all_joints_in_graph = [\n            joint\n            for joint in joints\n            if joint.parent.name in all_link_names_in_graph\n            and joint.child.name in all_link_names_in_graph\n        ]\n\n        # Get the names of all joints in the kinematic graph.\n        all_joint_names_in_graph = [j.name for j in all_joints_in_graph]\n\n        # Collect all the frames of the kinematic graph.\n        # Note: our parser ensures that the parent of a frame is not another frame.\n        all_frames_in_graph = [\n            frame for frame in frames if frame.parent_name in all_link_names_in_graph\n        ]\n\n        # Get the names of all frames in the kinematic graph.\n        all_frames_names_in_graph = [f.name for f in all_frames_in_graph]\n\n        # ============================\n        # Collect unconnected elements\n        # ============================\n\n        # Collect all the joints that are not part of the kinematic graph.\n        removed_joints = [j for j in joints if j.name not in all_joint_names_in_graph]\n\n        for joint in removed_joints:\n            msg = \"Joint '{}' is unconnected and it will be removed\"\n            logging.debug(msg=msg.format(joint.name))\n\n        # Collect all the links that are not part of the kinematic graph.\n        unconnected_links = [l for l in links if l.name not in all_link_names_in_graph]\n\n        # Update the unconnected links by removing their children. 
The other properties\n        # are left untouched, it's caller responsibility to post-process them if needed.\n        for link in unconnected_links:\n            link.children = tuple()\n            msg = \"Link '{}' won't be part of the kinematic graph because unconnected\"\n            logging.debug(msg=msg.format(link.name))\n\n        # Collect all the frames that are not part of the kinematic graph.\n        unconnected_frames = [\n            f for f in frames if f.name not in all_frames_names_in_graph\n        ]\n\n        for frame in unconnected_frames:\n            msg = \"Frame '{}' won't be part of the kinematic graph because unconnected\"\n            logging.debug(msg=msg.format(frame.name))\n\n        return (\n            links_dict[root_link_name].mutable(mutable=False),\n            list(set(joints) - set(removed_joints)),\n            all_frames_in_graph,\n            unconnected_links,\n            list(set(removed_joints)),\n            unconnected_frames,\n        )\n\n    def reduce(self, considered_joints: Sequence[str]) -> KinematicGraph:\n        \"\"\"\n        Reduce the kinematic graph by removing unspecified joints.\n\n        When a joint is removed, the mass and inertia of its child link are lumped\n        with those of its parent link, obtaining a new link that combines the two.\n        The description of the removed joint specifies the default angle (usually 0)\n        that is considered when the joint is removed.\n\n        Args:\n            considered_joints: A list of joint names to consider.\n\n        Returns:\n            The reduced kinematic graph.\n        \"\"\"\n\n        # The current object represents the complete kinematic graph\n        full_graph = self\n\n        # Get the names of the joints to remove\n        joint_names_to_remove = list(\n            set(full_graph.joint_names()) - set(considered_joints)\n        )\n\n        # Return early if there is no action to take\n        if 
len(joint_names_to_remove) == 0:\n            logging.info(\"The kinematic graph doesn't need to be reduced\")\n            return copy.deepcopy(self)\n\n        # Check if all considered joints are part of the full kinematic graph\n        if set(considered_joints) - {j.name for j in full_graph.joints}:\n            extra_j = set(considered_joints) - {j.name for j in full_graph.joints}\n            msg = f\"Not all joints to consider are part of the graph ({{{extra_j}}})\"\n            raise ValueError(msg)\n\n        # Extract data we need to modify from the full graph\n        links_dict = copy.deepcopy(full_graph.links_dict)\n        joints_dict = copy.deepcopy(full_graph.joints_dict)\n\n        # Create the object to compute forward kinematics.\n        fk = KinematicGraphTransforms(graph=full_graph)\n\n        # The following steps are implemented below in order to create the reduced graph:\n        #\n        # 1. Lump the mass of the removed links into their parent\n        # 2. Update the pose and parent link of joints having the removed link as parent\n        # 3. Create the reduced graph considering the removed links as frames\n        # 4. Resolve the pose of the frames wrt their reduced graph parent\n        #\n        # We name \"removed link\" the link to remove, and \"lumped link\" the new link that\n        # combines the removed link and its parent. The lumped link will share the frame\n        # of the removed link's parent and the inertial properties of the two links that\n        # have been combined.\n\n        # =======================================================\n        # 1. Lump the mass of the removed links into their parent\n        # =======================================================\n\n        # Get all the links to remove. 
They will be lumped with their parent.\n        links_to_remove = [\n            joint.child.name\n            for joint_name, joint in joints_dict.items()\n            if joint_name in joint_names_to_remove\n        ]\n\n        # Lump the mass and the inertia traversing the tree from the leaf to the root,\n        # this way we propagate these properties back even in the case when also the\n        # parent link of a removed joint has to be lumped with its parent.\n        for link in reversed(full_graph):\n            if link.name not in links_to_remove:\n                continue\n\n            # Get the link to remove and its parent, i.e. the lumped link\n            link_to_remove = links_dict[link.name]\n            parent_of_link_to_remove = links_dict[link.parent_name]\n\n            msg = \"Lumping chain: {}->({})->{}\"\n            logging.debug(\n                msg.format(\n                    link_to_remove.name,\n                    self.joints_connection_dict[\n                        parent_of_link_to_remove.name, link_to_remove.name\n                    ].name,\n                    parent_of_link_to_remove.name,\n                )\n            )\n\n            # Lump the link\n            lumped_link = parent_of_link_to_remove.lump_with(\n                link=link_to_remove,\n                lumped_H_removed=fk.relative_transform(\n                    relative_to=parent_of_link_to_remove.name, name=link_to_remove.name\n                ),\n            )\n\n            # Pop the original two links from the dictionary...\n            _ = links_dict.pop(link_to_remove.name)\n            _ = links_dict.pop(parent_of_link_to_remove.name)\n\n            # ... and insert the lumped link (having the same name of the parent)\n            links_dict[lumped_link.name] = lumped_link\n\n            # Insert back in the dict an entry from the removed link name to the new\n            # lumped link. 
We need this info later, when we process the remaining joints.\n            links_dict[link_to_remove.name] = lumped_link\n\n            # As a consequence of the back-insertion, we need to adjust the resulting\n            # lumped link of links that have been removed previously.\n            # Note: in the dictionary, only items whose key is not matching value.name\n            #       are links that have been removed.\n            for previously_removed_link_name in {\n                link_name\n                for link_name, link in links_dict.items()\n                if link_name != link.name and link.name == link_to_remove.name\n            }:\n                links_dict[previously_removed_link_name] = lumped_link\n\n        # ==============================================================================\n        # 2. Update the pose and parent link of joints having the removed link as parent\n        # ==============================================================================\n\n        # Find the joints having the removed links as parent\n        joints_with_removed_parent_link = [\n            joints_dict[joint_name]\n            for joint_name in considered_joints\n            if joints_dict[joint_name].parent.name in links_to_remove\n        ]\n\n        # Update the pose of all joints having as parent link a removed link\n        for joint in joints_with_removed_parent_link:\n            # Update the pose. 
Note that after the lumping process, the dict entry\n            # links_dict[joint.parent.name] contains the final lumped link\n            with joint.mutable_context(mutability=Mutability.MUTABLE):\n                joint.pose = fk.relative_transform(\n                    relative_to=links_dict[joint.parent.name].name, name=joint.name\n                )\n            with joint.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):\n                # Update the parent link\n                joint.parent = links_dict[joint.parent.name]\n\n        # ===================================================================\n        # 3. Create the reduced graph considering the removed links as frames\n        # ===================================================================\n\n        # Get all the original links from the full graph\n        full_graph_links_dict = copy.deepcopy(full_graph.links_dict)\n\n        # Get all the final links from the reduced graph\n        links_to_keep = [\n            l for link_name, l in links_dict.items() if link_name not in links_to_remove\n        ]\n\n        # Override the entries of the full graph with those of the reduced graph.\n        # Those that are not overridden will become frames.\n        for link in links_to_keep:\n            full_graph_links_dict[link.name] = link\n\n        # Create the reduced graph data. 
We pass the full list of links so that those\n        # that are not part of the graph will be returned as frames.\n        (\n            reduced_root_node,\n            reduced_joints,\n            reduced_frames,\n            unconnected_links,\n            unconnected_joints,\n            unconnected_frames,\n        ) = KinematicGraph._create_graph(\n            links=list(full_graph_links_dict.values()),\n            joints=[joints_dict[joint_name] for joint_name in considered_joints],\n            root_link_name=full_graph.root.name,\n        )\n\n        assert {f.name for f in self.frames}.isdisjoint(\n            {f.name for f in unconnected_frames + reduced_frames}\n        )\n\n        for link in unconnected_links:\n            logging.debug(msg=f\"Link '{link.name}' is unconnected and became a frame\")\n\n        # Create the reduced graph.\n        reduced_graph = KinematicGraph(\n            root=reduced_root_node,\n            joints=reduced_joints,\n            frames=self.frames + unconnected_links + reduced_frames,\n            root_pose=full_graph.root_pose,\n            _joints_removed=(\n                self._joints_removed\n                + unconnected_joints\n                + [joints_dict[name] for name in joint_names_to_remove]\n            ),\n        )\n\n        # ================================================================\n        # 4. 
Resolve the pose of the frames wrt their reduced graph parent\n        # ================================================================\n\n        # Build a new object to compute FK on the reduced graph.\n        fk_reduced = KinematicGraphTransforms(graph=reduced_graph)\n\n        # We need to adjust the pose of the frames since their parent link\n        # could have been removed by the reduction process.\n        for frame in reduced_graph.frames:\n\n            # Always find the real parent link of the frame\n            name_of_new_parent_link = fk_reduced.find_parent_link_of_frame(\n                name=frame.name\n            )\n            assert name_of_new_parent_link in reduced_graph, name_of_new_parent_link\n\n            # Notify the user if the parent link has changed.\n            if name_of_new_parent_link != frame.parent_name:\n                msg = \"New parent of frame '{}' is '{}'\"\n                logging.debug(msg=msg.format(frame.name, name_of_new_parent_link))\n\n            # Always recompute the pose of the frame, and set zero inertial params.\n            with frame.mutable_context(jaxsim.utils.Mutability.MUTABLE_NO_VALIDATION):\n\n                # Update kinematic parameters of the frame.\n                # Note that here we compute the transform using the FK object of the\n                # full model, so that we are sure that the kinematic is not altered.\n                frame.pose = fk.relative_transform(\n                    relative_to=name_of_new_parent_link, name=frame.name\n                )\n\n                # Update the parent link such that the pose is expressed in its frame.\n                frame.parent_name = name_of_new_parent_link\n\n                # Update dynamic parameters of the frame.\n                frame.mass = 0.0\n                frame.inertia = np.zeros_like(frame.inertia)\n\n        # Return the reduced graph.\n        return reduced_graph\n\n    def link_names(self) -> list[str]:\n        \"\"\"\n      
  Get the names of all links in the kinematic graph (i.e. the nodes).\n\n        Returns:\n            The list of link names.\n        \"\"\"\n        return list(self.links_dict.keys())\n\n    def joint_names(self) -> list[str]:\n        \"\"\"\n        Get the names of all joints in the kinematic graph (i.e. the edges).\n\n        Returns:\n            The list of joint names.\n        \"\"\"\n        return list(self.joints_dict.keys())\n\n    def frame_names(self) -> list[str]:\n        \"\"\"\n        Get the names of all frames in the kinematic graph.\n\n        Returns:\n            The list of frame names.\n        \"\"\"\n\n        return list(self.frames_dict.keys())\n\n    def print_tree(self) -> None:\n        \"\"\"\n        Print the tree structure of the kinematic graph.\n        \"\"\"\n\n        import pptree\n\n        root_node = self.root\n\n        pptree.print_tree(\n            root_node,\n            childattr=\"children\",\n            nameattr=\"name_and_index\",\n            horizontal=True,\n        )\n\n    @property\n    def joints_removed(self) -> list[JointDescription]:\n        \"\"\"\n        Get the list of joints removed during the graph reduction.\n\n        Returns:\n            The list of removed joints.\n        \"\"\"\n\n        return self._joints_removed\n\n    @staticmethod\n    def breadth_first_search(\n        root: LinkDescription,\n        sort_children: Callable[[Any], Any] | None = lambda link: link.name,\n    ) -> Iterable[LinkDescription]:\n        \"\"\"\n        Perform a breadth-first search (BFS) traversal of the kinematic graph.\n\n        Args:\n            root: The root link for BFS.\n            sort_children: A function to sort children of a node.\n\n        Yields:\n            The links in the kinematic graph in BFS order.\n        \"\"\"\n\n        # Initialize the queue with the root node.\n        queue = [root]\n\n        # We assume that nodes have unique names and mark a link as visited 
using\n        # its name. This speeds up considerably object comparison.\n        visited = []\n        visited.append(root.name)\n\n        yield root\n\n        while queue:\n\n            # Extract the first element of the queue.\n            l = queue.pop(0)\n\n            # Note: sorting the links with their name so that the order of children\n            # insertion does not matter when assigning the link index.\n            for child in sorted(l.children, key=sort_children):\n\n                if child.name in visited:\n                    continue\n\n                visited.append(child.name)\n                queue.append(child)\n\n                yield child\n\n    # =================\n    # Sequence protocol\n    # =================\n\n    def __iter__(self) -> Iterator[LinkDescription]:\n        yield from KinematicGraph.breadth_first_search(root=self.root)\n\n    def __reversed__(self) -> Iterable[LinkDescription]:\n        yield from reversed(list(iter(self)))\n\n    def __len__(self) -> int:\n        return len(list(iter(self)))\n\n    def __contains__(self, item: str | LinkDescription) -> bool:\n        if isinstance(item, str):\n            return item in self.link_names()\n\n        if isinstance(item, LinkDescription):\n            return item in set(iter(self))\n\n        raise TypeError(type(item).__name__)\n\n    def __getitem__(self, key: int | str) -> LinkDescription:\n        if isinstance(key, str):\n            if key not in self.link_names():\n                raise KeyError(key)\n\n            return self.links_dict[key]\n\n        if isinstance(key, int):\n            if key > len(self):\n                raise KeyError(key)\n\n            return list(iter(self))[key]\n\n        raise TypeError(type(key).__name__)\n\n    def count(self, value: LinkDescription) -> int:\n        \"\"\"\n        Count the occurrences of a link in the kinematic graph.\n        \"\"\"\n        return list(iter(self)).count(value)\n\n    def index(self, value: 
LinkDescription, start: int = 0, stop: int = -1) -> int:\n        \"\"\"\n        Find the index of a link in the kinematic graph.\n        \"\"\"\n        return list(iter(self)).index(value, start, stop)\n\n\n# ====================\n# Other useful classes\n# ====================\n\n\n@dataclasses.dataclass(frozen=True)\nclass KinematicGraphTransforms:\n    \"\"\"\n    Class to compute forward kinematics on a kinematic graph.\n\n    Attributes:\n        graph: The kinematic graph on which to compute forward kinematics.\n    \"\"\"\n\n    graph: KinematicGraph\n\n    _transform_cache: dict[str, npt.NDArray] = dataclasses.field(\n        default_factory=dict, init=False, repr=False, compare=False\n    )\n\n    _initial_joint_positions: dict[str, float] = dataclasses.field(\n        init=False, repr=False, compare=False\n    )\n\n    def __post_init__(self) -> None:\n\n        super().__setattr__(\n            \"_initial_joint_positions\",\n            {joint.name: joint.initial_position for joint in self.graph.joints},\n        )\n\n    @property\n    def initial_joint_positions(self) -> npt.NDArray:\n        \"\"\"\n        Get the initial joint positions of the kinematic graph.\n        \"\"\"\n\n        return np.atleast_1d(\n            np.array(list(self._initial_joint_positions.values()))\n        ).astype(float)\n\n    @initial_joint_positions.setter\n    def initial_joint_positions(\n        self,\n        positions: npt.NDArray | Sequence,\n        joint_names: Sequence[str] | None = None,\n    ) -> None:\n\n        joint_names = (\n            joint_names\n            if joint_names is not None\n            else list(self._initial_joint_positions.keys())\n        )\n\n        s = np.atleast_1d(np.array(positions).squeeze())\n\n        if s.size != len(joint_names):\n            raise ValueError(s.size, len(joint_names))\n\n        for joint_name in joint_names:\n            if joint_name not in self._initial_joint_positions:\n                raise 
ValueError(joint_name)\n\n        # Clear transform cache.\n        self._transform_cache.clear()\n\n        # Update initial joint positions.\n        for joint_name, position in zip(joint_names, s, strict=True):\n            self._initial_joint_positions[joint_name] = position\n\n    def transform(self, name: str) -> npt.NDArray:\n        \"\"\"\n        Compute the SE(3) transform of elements belonging to the kinematic graph.\n\n        Args:\n            name: The name of a link, a joint, or a frame.\n\n        Returns:\n            The 4x4 transform matrix of the element w.r.t. the model frame.\n        \"\"\"\n\n        # If the transform was already computed, return it.\n        if name in self._transform_cache:\n            return self._transform_cache[name]\n\n        # If the name is a joint, compute M_H_J transform.\n        if name in self.graph.joint_names():\n\n            # Get the joint.\n            joint = self.graph.joints_dict[name]\n            assert joint.name == name\n\n            # Get the transform of the parent link.\n            M_H_L = self.transform(name=joint.parent.name)\n\n            # Rename the pose of the predecessor joint frame w.r.t. 
its parent link.\n            L_H_pre = joint.pose\n\n            # Compute the joint transform from the predecessor to the successor frame.\n            pre_H_J = self.pre_H_suc(\n                joint_type=joint.jtype,\n                joint_axis=joint.axis,\n                joint_position=self._initial_joint_positions[joint.name],\n            )\n\n            # Compute the M_H_J transform.\n            self._transform_cache[name] = M_H_L @ L_H_pre @ pre_H_J\n            return self._transform_cache[name]\n\n        # If the name is a link, compute M_H_L transform.\n        if name in self.graph.link_names():\n\n            # Get the link.\n            link = self.graph.links_dict[name]\n\n            # Handle the pose between the __model__ frame and the root link.\n            if link.name == self.graph.root.name:\n                M_H_B = link.pose\n                return M_H_B\n\n            # Get the joint between the link and its parent.\n            parent_joint = self.graph.joints_connection_dict[\n                link.parent_name, link.name\n            ]\n\n            # Get the transform of the parent joint.\n            M_H_J = self.transform(name=parent_joint.name)\n\n            # Rename the pose of the link w.r.t. its parent joint.\n            J_H_L = link.pose\n\n            # Compute the M_H_L transform.\n            self._transform_cache[name] = M_H_J @ J_H_L\n            return self._transform_cache[name]\n\n        # It can only be a plain frame.\n        if name not in self.graph.frame_names():\n            raise ValueError(name)\n\n        # Get the frame.\n        frame = self.graph.frames_dict[name]\n\n        # Get the transform of the parent link.\n        M_H_L = self.transform(name=frame.parent_name)\n\n        # Rename the pose of the frame w.r.t. 
its parent link.\n        L_H_F = frame.pose\n\n        # Compute the M_H_F transform.\n        self._transform_cache[name] = M_H_L @ L_H_F\n        return self._transform_cache[name]\n\n    def relative_transform(self, relative_to: str, name: str) -> npt.NDArray:\n        \"\"\"\n        Compute the SE(3) relative transform of elements belonging to the kinematic graph.\n\n        Args:\n            relative_to: The name of the reference element.\n            name: The name of a link, a joint, or a frame.\n\n        Returns:\n            The 4x4 transform matrix of the element w.r.t. the desired frame.\n        \"\"\"\n\n        import jaxsim.math\n\n        M_H_target = self.transform(name=name)\n        M_H_R = self.transform(name=relative_to)\n\n        # Compute the relative transform R_H_target, where R is the reference frame,\n        # and i the frame of the desired link|joint|frame.\n        return np.array(jaxsim.math.Transform.inverse(M_H_R)) @ M_H_target\n\n    @staticmethod\n    def pre_H_suc(\n        joint_type: JointType,\n        joint_axis: npt.NDArray,\n        joint_position: float | None = None,\n    ) -> npt.NDArray:\n        \"\"\"\n        Compute the SE(3) transform from the predecessor to the successor frame.\n\n        Args:\n            joint_type: The type of the joint.\n            joint_axis: The axis of the joint.\n            joint_position: The position of the joint.\n\n        Returns:\n            The 4x4 transform matrix from the predecessor to the successor frame.\n        \"\"\"\n\n        import jaxsim.math\n\n        return np.array(\n            jaxsim.math.supported_joint_motion(joint_type, joint_position, joint_axis)\n        )\n\n    def find_parent_link_of_frame(self, name: str) -> str:\n        \"\"\"\n        Find the parent link of a frame.\n\n        Args:\n            name: The name of the frame.\n\n        Returns:\n            The name of the parent link of the frame.\n        \"\"\"\n\n        try:\n            
frame = self.graph.frames_dict[name]\n        except KeyError as e:\n            raise ValueError(f\"Frame '{name}' not found in the kinematic graph\") from e\n\n        if frame.parent_name in self.graph.links_dict:\n            return frame.parent_name\n        if frame.parent_name in self.graph.frames_dict:\n            return self.find_parent_link_of_frame(name=frame.parent_name)\n\n        msg = f\"Failed to find parent element of frame '{name}' with name '{frame.parent_name}'\"\n        raise RuntimeError(msg)\n"
  },
  {
    "path": "src/jaxsim/parsers/rod/__init__.py",
    "content": "from . import parser\nfrom . import utils\nfrom .parser import (\n    build_model_description,\n    extract_model_data,\n)\n"
  },
  {
    "path": "src/jaxsim/parsers/rod/meshes.py",
    "content": "import numpy as np\nimport trimesh\n\nVALID_AXIS = {\"x\": 0, \"y\": 1, \"z\": 2}\n\n\ndef extract_points_vertices(mesh: trimesh.Trimesh) -> np.ndarray:\n    \"\"\"\n    Extract the vertices of a mesh as points.\n    \"\"\"\n    return mesh.vertices\n\n\ndef extract_points_random_surface_sampling(mesh: trimesh.Trimesh, n) -> np.ndarray:\n    \"\"\"\n    Extract N random points from the surface of a mesh.\n\n    Args:\n        mesh: The mesh from which to extract points.\n        n: The number of points to extract.\n\n    Returns:\n        The extracted points (N x 3 array).\n    \"\"\"\n\n    return mesh.sample(n)\n\n\ndef extract_points_uniform_surface_sampling(\n    mesh: trimesh.Trimesh, n: int\n) -> np.ndarray:\n    \"\"\"\n    Extract N uniformly sampled points from the surface of a mesh.\n\n    Args:\n        mesh: The mesh from which to extract points.\n        n: The number of points to extract.\n\n    Returns:\n        The extracted points (N x 3 array).\n    \"\"\"\n\n    return trimesh.sample.sample_surface_even(mesh=mesh, count=n)[0]\n\n\ndef extract_points_select_points_over_axis(\n    mesh: trimesh.Trimesh, axis: str, direction: str, n: int\n) -> np.ndarray:\n    \"\"\"\n    Extract N points from a mesh along a specified axis. The points are selected based on their position along the axis.\n\n    Args:\n        mesh: The mesh from which to extract points.\n        axis: The axis along which to extract points. Valid values are \"x\", \"y\" and \"z\".\n        direction: The direction along the axis from which to extract points. Valid values are \"higher\" and \"lower\".\n        n: The number of points to extract.\n\n    Returns:\n        The extracted points (N x 3 array), i.e. the N mesh vertices with the\n        highest or lowest coordinate along the axis, sorted by increasing\n        coordinate. Fewer points are returned if the mesh has fewer than N vertices.\n\n    Raises:\n        KeyError: If the axis or the direction is not valid.\n    \"\"\"\n\n    # Read the vertices without modifying the mesh.\n    vertices = np.asarray(mesh.vertices)\n    num_vertices = vertices.shape[0]\n\n    # The start of the \"higher\" slice is computed explicitly so that n=0 selects\n    # no points (a plain [-n:] slice would select all of them).\n    dirs = {\"higher\": np.s_[max(num_vertices - n, 0) :], \"lower\": np.s_[:n]}\n    selection = dirs[direction]\n\n    # Sort whole rows by their coordinate along the selected axis, so that every\n    # returned point is an actual mesh vertex. Fancy indexing returns a copy.\n    order = np.argsort(vertices[:, VALID_AXIS[axis]], kind=\"stable\")\n    return vertices[order][selection]\n\n\ndef extract_points_aap(\n    mesh: trimesh.Trimesh,\n    axis: str,\n    upper: float | None = None,\n    lower: float | None = None,\n) -> np.ndarray:\n    \"\"\"\n    Extract points from a mesh along a specified axis within a specified range. The points are selected based on their position along the axis.\n\n    Args:\n        mesh: The mesh from which to extract points.\n        axis: The axis along which to extract points.\n        upper: The upper bound of the range.\n        lower: The lower bound of the range.\n\n    Returns:\n        The extracted points (N x 3 array).\n\n    Raises:\n        AssertionError: If the lower bound is not smaller than the upper bound.\n    \"\"\"\n\n    # Check bounds.\n    upper = upper if upper is not None else np.inf\n    lower = lower if lower is not None else -np.inf\n    assert lower < upper, \"Invalid bounds for axis-aligned plane\"\n\n    # Logic.\n    points = mesh.vertices[\n        (mesh.vertices[:, VALID_AXIS[axis]] >= lower)\n        & (mesh.vertices[:, VALID_AXIS[axis]] <= upper)\n    ]\n\n    return points\n"
  },
  {
    "path": "src/jaxsim/parsers/rod/parser.py",
    "content": "import dataclasses\nimport os\nimport pathlib\nfrom typing import NamedTuple\n\nimport jax.numpy as jnp\nimport numpy as np\nimport rod\n\nfrom jaxsim import logging\nfrom jaxsim.math import Quaternion\nfrom jaxsim.parsers import descriptions, kinematic_graph\n\nfrom . import utils\n\n\nclass SDFData(NamedTuple):\n    \"\"\"\n    Data extracted from an SDF resource useful to build a JaxSim model.\n    \"\"\"\n\n    model_name: str\n\n    fixed_base: bool\n    base_link_name: str\n\n    link_descriptions: list[descriptions.LinkDescription]\n    joint_descriptions: list[descriptions.JointDescription]\n    frame_descriptions: list[descriptions.LinkDescription]\n    collision_shapes: list[descriptions.CollisionShape]\n\n    sdf_model: rod.Model | None = None\n    model_pose: kinematic_graph.RootPose = kinematic_graph.RootPose()\n\n\ndef extract_model_data(\n    model_description: pathlib.Path | str | rod.Model | rod.Sdf,\n    model_name: str | None = None,\n    is_urdf: bool | None = None,\n) -> SDFData:\n    \"\"\"\n    Extract data from an SDF/URDF resource useful to build a JaxSim model.\n\n    Args:\n        model_description:\n            A path to an SDF/URDF file, a string containing its content, or\n            a pre-parsed/pre-built rod model.\n        model_name: The name of the model to extract from the SDF resource.\n        is_urdf:\n            Whether to force parsing the resource as a URDF file. 
Automatically\n            detected if not provided.\n\n    Returns:\n        The extracted model data.\n    \"\"\"\n\n    match model_description:\n        case rod.Model():\n            sdf_model = model_description\n        case rod.Sdf() | str() | pathlib.Path():\n            sdf_element = (\n                model_description\n                if isinstance(model_description, rod.Sdf)\n                else rod.Sdf.load(sdf=model_description, is_urdf=is_urdf)\n            )\n            if not sdf_element.models():\n                raise RuntimeError(\"Failed to find any model in SDF resource\")\n\n            # Assume the SDF resource has only one model, or the desired model name is given.\n            sdf_models = {m.name: m for m in sdf_element.models()}\n            sdf_model = (\n                sdf_element.models()[0]\n                if len(sdf_models) == 1\n                else sdf_models[model_name]\n            )\n\n    # Log model name.\n    logging.info(msg=f\"Found model '{sdf_model.name}' in SDF resource\")\n\n    # Jaxsim supports only models compatible with URDF, i.e. 
those having all links\n    # directly attached to their parent joint without additional roto-translations.\n    # Furthermore, the following switch also post-processes frames such that their\n    # pose is expressed wrt the parent link they are rigidly attached to.\n    sdf_model.switch_frame_convention(frame_convention=rod.FrameConvention.Urdf)\n\n    # Log type of base link.\n    logging.debug(\n        msg=f\"Model '{sdf_model.name}' is {'fixed-base' if sdf_model.is_fixed_base() else 'floating-base'}\"\n    )\n\n    # Log detected base link.\n    logging.debug(msg=f\"Considering '{sdf_model.get_canonical_link()}' as base link\")\n\n    # Pose of the model\n    if sdf_model.pose is None:\n        model_pose = kinematic_graph.RootPose()\n\n    else:\n        W_H_M = sdf_model.pose.transform()\n        model_pose = kinematic_graph.RootPose(\n            root_position=W_H_M[0:3, 3],\n            root_quaternion=Quaternion.from_dcm(dcm=W_H_M[0:3, 0:3]),\n        )\n\n    # ===========\n    # Parse links\n    # ===========\n\n    # Parse the links (unconnected).\n    links = [\n        descriptions.LinkDescription(\n            name=l.name,\n            mass=float(l.inertial.mass),\n            inertia=utils.from_sdf_inertial(inertial=l.inertial),\n            pose=l.pose.transform() if l.pose is not None else np.eye(4),\n        )\n        for l in sdf_model.links()\n        if l.inertial.mass > 0\n    ]\n\n    # Create a dictionary to find easily links.\n    links_dict: dict[str, descriptions.LinkDescription] = {l.name: l for l in links}\n\n    # ============\n    # Parse frames\n    # ============\n\n    # Parse the frames (unconnected).\n    frames = [\n        descriptions.LinkDescription(\n            name=f.name,\n            mass=jnp.array(0.0, dtype=float),\n            inertia=jnp.zeros(shape=(3, 3)),\n            parent_name=f.attached_to,\n            pose=f.pose.transform() if f.pose is not None else jnp.eye(4),\n        )\n        for f in 
sdf_model.frames()\n        if f.attached_to in links_dict\n    ]\n\n    # =========================\n    # Process fixed-base models\n    # =========================\n\n    # In this case, we need to get the pose of the joint that connects the base link\n    # to the world and combine their pose.\n    if sdf_model.is_fixed_base():\n        # Create a massless word link\n        world_link = descriptions.LinkDescription(\n            name=\"world\", mass=0, inertia=np.zeros(shape=(6, 6))\n        )\n\n        # Gather joints connecting fixed-base models to the world.\n        # TODO: the pose of this joint could be expressed wrt any arbitrary frame,\n        #       here we assume is expressed wrt the model. This also means that the\n        #       default model pose matches the pose of the fake \"world\" link.\n        joints_with_world_parent = [\n            descriptions.JointDescription(\n                name=j.name,\n                parent=world_link,\n                child=links_dict[j.child],\n                jtype=utils.joint_to_joint_type(joint=j),\n                axis=(\n                    np.array(j.axis.xyz.xyz)\n                    if j.axis is not None\n                    and j.axis.xyz is not None\n                    and j.axis.xyz.xyz is not None\n                    else None\n                ),\n                pose=j.pose.transform() if j.pose is not None else np.eye(4),\n            )\n            for j in sdf_model.joints()\n            if j.type == \"fixed\"\n            and j.parent == \"world\"\n            and j.child in links_dict\n            and j.pose.relative_to in {\"__model__\", \"world\", None}\n        ]\n\n        logging.debug(\n            f\"Found joints connecting to world: {[j.name for j in joints_with_world_parent]}\"\n        )\n\n        if len(joints_with_world_parent) != 1:\n            msg = \"Found more/less than one joint connecting a fixed-base model to the world\"\n            raise ValueError(msg + f\": 
{[j.name for j in joints_with_world_parent]}\")\n\n        base_link_name = joints_with_world_parent[0].child.name\n\n        msg = \"Combining the pose of base link '{}' with the pose of joint '{}'\"\n        logging.debug(msg.format(base_link_name, joints_with_world_parent[0].name))\n\n        # Combine the pose of the base link (child of the found fixed joint)\n        # with the pose of the fixed joint connecting with the world.\n        # Note: we assume it's a fixed joint and ignore any joint angle.\n        links_dict[base_link_name].mutable(validate=False).pose = (\n            joints_with_world_parent[0].pose @ links_dict[base_link_name].pose\n        )\n\n    # ============\n    # Parse joints\n    # ============\n\n    # Check that all joint poses are expressed w.r.t. their parent link.\n    for j in sdf_model.joints():\n        if j.pose is None:\n            continue\n\n        if j.parent == \"world\":\n            if j.pose.relative_to in {\"__model__\", \"world\", None}:\n                continue\n\n            raise ValueError(\"Pose of fixed joint connecting to 'world' link not valid\")\n\n        if j.pose.relative_to != j.parent:\n            msg = \"Pose of joint '{}' is not expressed wrt its parent link '{}'\"\n            raise ValueError(msg.format(j.name, j.parent))\n\n    # Parse the joints.\n    joints = [\n        descriptions.JointDescription(\n            name=j.name,\n            parent=links_dict[j.parent],\n            child=links_dict[j.child],\n            jtype=utils.joint_to_joint_type(joint=j),\n            axis=(\n                np.array(j.axis.xyz.xyz, dtype=float)\n                if j.axis is not None\n                and j.axis.xyz is not None\n                and j.axis.xyz.xyz is not None\n                else None\n            ),\n            pose=j.pose.transform() if j.pose is not None else np.eye(4),\n            initial_position=0.0,\n            position_limit=(\n                float(\n                    
j.axis.limit.lower\n                    if j.axis is not None\n                    and j.axis.limit is not None\n                    and j.axis.limit.lower is not None\n                    else jnp.finfo(float).min\n                ),\n                float(\n                    j.axis.limit.upper\n                    if j.axis is not None\n                    and j.axis.limit is not None\n                    and j.axis.limit.upper is not None\n                    else jnp.finfo(float).max\n                ),\n            ),\n            friction_static=float(\n                j.axis.dynamics.friction\n                if j.axis is not None\n                and j.axis.dynamics is not None\n                and j.axis.dynamics.friction is not None\n                else 0.0\n            ),\n            friction_viscous=float(\n                j.axis.dynamics.damping\n                if j.axis is not None\n                and j.axis.dynamics is not None\n                and j.axis.dynamics.damping is not None\n                else 0.0\n            ),\n            position_limit_damper=float(\n                j.axis.limit.dissipation\n                if j.axis is not None\n                and j.axis.limit is not None\n                and j.axis.limit.dissipation is not None\n                else os.environ.get(\"JAXSIM_JOINT_POSITION_LIMIT_DAMPER\", 0.0)\n            ),\n            position_limit_spring=float(\n                j.axis.limit.stiffness\n                if j.axis is not None\n                and j.axis.limit is not None\n                and j.axis.limit.stiffness is not None\n                else os.environ.get(\"JAXSIM_JOINT_POSITION_LIMIT_SPRING\", 0.0)\n            ),\n        )\n        for j in sdf_model.joints()\n        if j.type in {\"revolute\", \"continuous\", \"prismatic\", \"fixed\"}\n        and j.parent != \"world\"\n        and j.child in links_dict\n    ]\n\n    # Create a dictionary to find the parent joint of the links.\n    joint_dict = 
{j.child.name: j.name for j in joints}\n\n    # Check that all the link poses are expressed wrt their parent joint.\n    for l in sdf_model.links():\n        if l.name not in links_dict:\n            continue\n\n        if l.pose is None:\n            continue\n\n        if l.name == sdf_model.get_canonical_link():\n            continue\n\n        if l.name not in joint_dict:\n            raise ValueError(f\"Failed to find parent joint of link '{l.name}'\")\n\n        if l.pose.relative_to != joint_dict[l.name]:\n            msg = \"Pose of link '{}' is not expressed wrt its parent joint '{}'\"\n            raise ValueError(msg.format(l.name, joint_dict[l.name]))\n\n    # ================\n    # Parse collisions\n    # ================\n\n    # Initialize the collision shapes\n    collisions: list[descriptions.CollisionShape] = []\n\n    # Parse the collisions\n    for link in sdf_model.links():\n        for collision in link.collisions():\n            if collision.geometry.box is not None:\n                box_collision = utils.create_box_collision(\n                    collision=collision,\n                    link_description=links_dict[link.name],\n                )\n\n                collisions.append(box_collision)\n                continue\n\n            if collision.geometry.sphere is not None:\n                sphere_collision = utils.create_sphere_collision(\n                    collision=collision,\n                    link_description=links_dict[link.name],\n                )\n\n                collisions.append(sphere_collision)\n                continue\n\n            if collision.geometry.mesh is not None:\n                if int(os.environ.get(\"JAXSIM_COLLISION_MESH_ENABLED\", \"0\")):\n                    logging.warning(\"Mesh collision support is still experimental.\")\n                    mesh_collision = utils.create_mesh_collision(\n                        collision=collision,\n                        link_description=links_dict[link.name],\n 
                       method=utils.meshes.extract_points_vertices,\n                    )\n\n                    collisions.append(mesh_collision)\n\n                else:\n                    logging.warning(\n                        f\"Skipping collision shape 'mesh' in link '{link.name}' because mesh collisions are disabled.\"\n                    )\n\n                continue\n\n            # Check any remaining non-None geometry types.\n            for attr_name in collision.geometry.__dict__:\n                if getattr(collision.geometry, attr_name) is not None:\n                    logging.warning(\n                        f\"Skipping collision shape '{attr_name}' in link '{link.name}' as not supported.\"\n                    )\n\n    return SDFData(\n        model_name=sdf_model.name,\n        link_descriptions=links,\n        joint_descriptions=joints,\n        frame_descriptions=frames,\n        collision_shapes=collisions,\n        fixed_base=sdf_model.is_fixed_base(),\n        base_link_name=sdf_model.get_canonical_link(),\n        model_pose=model_pose,\n        sdf_model=sdf_model,\n    )\n\n\ndef build_model_description(\n    model_description: pathlib.Path | str | rod.Model,\n    is_urdf: bool | None = None,\n) -> descriptions.ModelDescription:\n    \"\"\"\n    Build a model description from an SDF/URDF resource.\n\n    Args:\n        model_description: A path to an SDF/URDF file, a string containing its content,\n          or a pre-parsed/pre-built rod model.\n        is_urdf: Whether to force parsing the resource as a URDF file. 
Automatically\n            detected if not provided.\n\n    Returns:\n        The parsed model description.\n    \"\"\"\n\n    # Parse data from the SDF assuming it contains a single model.\n    sdf_data = extract_model_data(\n        model_description=model_description, model_name=None, is_urdf=is_urdf\n    )\n\n    # Build the intermediate representation used for building a JaxSim model.\n    # This process, beyond other operations, removes the fixed joints.\n    # Note: if the model is fixed-base, the fixed joint between world and the first\n    #       link is removed and the pose of the first link is updated.\n    #\n    # The whole process is:\n    # URDF/SDF ⟶ rod.Model ⟶ ModelDescription ⟶ JaxSimModel.\n    graph = descriptions.ModelDescription.build_model_from(\n        name=sdf_data.model_name,\n        links=sdf_data.link_descriptions,\n        joints=sdf_data.joint_descriptions,\n        frames=sdf_data.frame_descriptions,\n        collisions=sdf_data.collision_shapes,\n        fixed_base=sdf_data.fixed_base,\n        base_link_name=sdf_data.base_link_name,\n        model_pose=sdf_data.model_pose,\n        considered_joints=[\n            j.name\n            for j in sdf_data.joint_descriptions\n            if j.jtype is not descriptions.JointType.Fixed\n        ],\n    )\n\n    # Store the parsed SDF tree as extra info\n    graph = dataclasses.replace(graph, _extra_info={\"sdf_model\": sdf_data.sdf_model})\n\n    return graph\n"
  },
  {
    "path": "src/jaxsim/parsers/rod/utils.py",
    "content": "import os\nimport pathlib\nfrom collections.abc import Callable\nfrom typing import TypeVar\n\nimport numpy as np\nimport numpy.typing as npt\nimport rod\nimport trimesh\nfrom rod.utils.resolve_uris import resolve_local_uri\n\nimport jaxsim.typing as jtp\nfrom jaxsim import logging\nfrom jaxsim.math import Adjoint, Inertia\nfrom jaxsim.parsers import descriptions\nfrom jaxsim.parsers.rod import meshes\n\nMeshMappingMethod = TypeVar(\"MeshMappingMethod\", bound=Callable[..., npt.NDArray])\n\n\ndef from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:\n    \"\"\"\n    Extract the 6D inertia matrix from an SDF inertial element.\n\n    Args:\n        inertial: The SDF inertial element.\n\n    Returns:\n        The 6D inertia matrix of the link expressed in the link frame.\n    \"\"\"\n\n    # Extract the \"mass\" element.\n    m = inertial.mass\n\n    # Extract the \"inertia\" element.\n    inertia_element = inertial.inertia\n\n    ixx = inertia_element.ixx\n    iyy = inertia_element.iyy\n    izz = inertia_element.izz\n    ixy = inertia_element.ixy if inertia_element.ixy is not None else 0.0\n    ixz = inertia_element.ixz if inertia_element.ixz is not None else 0.0\n    iyz = inertia_element.iyz if inertia_element.iyz is not None else 0.0\n\n    # Build the 3x3 inertia matrix expressed in the CoM.\n    I_CoM = np.array(\n        [\n            [ixx, ixy, ixz],\n            [ixy, iyy, iyz],\n            [ixz, iyz, izz],\n        ]\n    )\n\n    # Build the 6x6 generalized inertia at the CoM.\n    M_CoM = Inertia.to_sixd(mass=m, com=np.zeros(3), I=I_CoM)\n\n    # Compute the transform from the inertial frame (CoM) to the link frame.\n    L_H_CoM = inertial.pose.transform() if inertial.pose is not None else np.eye(4)\n\n    # We need its inverse.\n    CoM_X_L = Adjoint.from_transform(transform=L_H_CoM, inverse=True)\n\n    # Express the CoM inertia matrix in the link frame L.\n    M_L = CoM_X_L.T @ M_CoM @ CoM_X_L\n\n    return 
M_L.astype(dtype=float)\n\n\ndef joint_to_joint_type(joint: rod.Joint) -> int:\n    \"\"\"\n    Extract the joint type from an SDF joint.\n\n    Args:\n        joint: The parsed SDF joint.\n\n    Returns:\n        The integer corresponding to the joint type.\n    \"\"\"\n\n    axis = joint.axis\n    joint_type = joint.type\n\n    if joint_type == \"fixed\":\n        return descriptions.JointType.Fixed\n\n    if not (axis.xyz is not None and axis.xyz.xyz is not None):\n        raise ValueError(\"Failed to read axis xyz data\")\n\n    # Make sure that the axis is a unary vector.\n    axis_xyz = np.array(axis.xyz.xyz).astype(float)\n    axis_xyz = axis_xyz / np.linalg.norm(axis_xyz)\n\n    if joint_type in {\"revolute\", \"continuous\"}:\n        return descriptions.JointType.Revolute\n\n    if joint_type == \"prismatic\":\n        return descriptions.JointType.Prismatic\n\n    raise ValueError(\"Joint not supported\", axis_xyz, joint_type)\n\n\ndef create_box_collision(\n    collision: rod.Collision, link_description: descriptions.LinkDescription\n) -> descriptions.BoxCollision:\n    \"\"\"\n    Create a box collision from an SDF collision element.\n\n    Args:\n        collision: The SDF collision element.\n        link_description: The link description.\n\n    Returns:\n        The box collision description.\n    \"\"\"\n\n    x, y, z = collision.geometry.box.size\n\n    center = np.array([x / 2, y / 2, z / 2])\n\n    # Define the bottom corners.\n    bottom_corners = np.array([[0, 0, 0], [x, 0, 0], [x, y, 0], [0, y, 0]])\n\n    # Conditionally add the top corners based on the environment variable.\n    top_corners = (\n        np.array([[0, 0, z], [x, 0, z], [x, y, z], [0, y, z]])\n        if os.environ.get(\"JAXSIM_COLLISION_USE_BOTTOM_ONLY\", \"0\").lower()\n        in {\n            \"false\",\n            \"0\",\n        }\n        else []\n    )\n\n    # Combine and shift by the center\n    box_corners = np.vstack([bottom_corners, *top_corners]) - center\n\n  
  H = collision.pose.transform() if collision.pose is not None else np.eye(4)\n\n    center_wrt_link = (H @ np.hstack([center, 1.0]))[0:-1]\n    box_corners_wrt_link = (\n        H @ np.hstack([box_corners, np.vstack([1.0] * box_corners.shape[0])]).T\n    )[0:3, :]\n\n    collidable_points = [\n        descriptions.CollidablePoint(\n            parent_link=link_description,\n            position=np.array(corner),\n            enabled=True,\n        )\n        for corner in box_corners_wrt_link.T\n    ]\n\n    return descriptions.BoxCollision(\n        collidable_points=collidable_points, center=center_wrt_link\n    )\n\n\ndef create_sphere_collision(\n    collision: rod.Collision, link_description: descriptions.LinkDescription\n) -> descriptions.SphereCollision:\n    \"\"\"\n    Create a sphere collision from an SDF collision element.\n\n    Args:\n        collision: The SDF collision element.\n        link_description: The link description.\n\n    Returns:\n        The sphere collision description.\n    \"\"\"\n\n    # From https://stackoverflow.com/a/26127012\n    def fibonacci_sphere(samples: int) -> npt.NDArray:\n        # Get the golden ratio in radians.\n        phi = np.pi * (3.0 - np.sqrt(5.0))\n\n        # Generate the points.\n        points = [\n            np.array(\n                [\n                    np.cos(phi * i)\n                    * np.sqrt(1 - (y := 1 - 2 * i / (samples - 1)) ** 2),\n                    y,\n                    np.sin(phi * i) * np.sqrt(1 - y**2),\n                ]\n            )\n            for i in range(samples)\n        ]\n\n        # Filter to keep only the bottom half if required.\n        if os.environ.get(\"JAXSIM_COLLISION_USE_BOTTOM_ONLY\", \"0\").lower() in {\n            \"true\",\n            \"1\",\n        }:\n            # Keep only the points with z <= 0.\n            points = [point for point in points if point[2] <= 0]\n\n        return np.vstack(points)\n\n    r = collision.geometry.sphere.radius\n\n    
sphere_points = r * fibonacci_sphere(\n        samples=int(os.getenv(key=\"JAXSIM_COLLISION_SPHERE_POINTS\", default=\"50\"))\n    )\n\n    H = collision.pose.transform() if collision.pose is not None else np.eye(4)\n\n    center_wrt_link = (H @ np.hstack([0, 0, 0, 1.0]))[0:-1]\n\n    sphere_points_wrt_link = (\n        H @ np.hstack([sphere_points, np.vstack([1.0] * sphere_points.shape[0])]).T\n    )[0:3, :]\n\n    collidable_points = [\n        descriptions.CollidablePoint(\n            parent_link=link_description,\n            position=np.array(point),\n            enabled=True,\n        )\n        for point in sphere_points_wrt_link.T\n    ]\n\n    return descriptions.SphereCollision(\n        collidable_points=collidable_points, center=center_wrt_link\n    )\n\n\ndef create_mesh_collision(\n    collision: rod.Collision,\n    link_description: descriptions.LinkDescription,\n    method: MeshMappingMethod = None,\n) -> descriptions.MeshCollision:\n    \"\"\"\n    Create a mesh collision from an SDF collision element.\n\n    Args:\n        collision: The SDF collision element.\n        link_description: The link description.\n        method: The method to use for mesh wrapping.\n\n    Returns:\n        The mesh collision description.\n    \"\"\"\n\n    file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri))\n    file_type = file.suffix.replace(\".\", \"\")\n    mesh = trimesh.load_mesh(file, file_type=file_type)\n\n    if mesh.is_empty:\n        raise RuntimeError(f\"Failed to process '{file}' with trimesh\")\n\n    mesh.apply_scale(collision.geometry.mesh.scale)\n    logging.info(\n        msg=f\"Loading mesh {collision.geometry.mesh.uri} with scale {collision.geometry.mesh.scale}, file type '{file_type}'\"\n    )\n\n    if method is None:\n        method = meshes.VertexExtraction()\n        logging.debug(\"Using default Vertex Extraction method for mesh wrapping\")\n    else:\n        logging.debug(f\"Using method {method} for mesh 
wrapping\")\n\n    points = method(mesh=mesh)\n    logging.debug(f\"Extracted {len(points)} points from mesh\")\n\n    W_H_L = collision.pose.transform() if collision.pose is not None else np.eye(4)\n\n    # Extract translation from transformation matrix\n    W_p_L = W_H_L[:3, 3]\n    mesh_points_wrt_link = points @ W_H_L[:3, :3].T + W_p_L\n    collidable_points = [\n        descriptions.CollidablePoint(\n            parent_link=link_description,\n            position=point,\n            enabled=True,\n        )\n        for point in mesh_points_wrt_link\n    ]\n\n    return descriptions.MeshCollision(collidable_points=collidable_points, center=W_p_L)\n\n\ndef prepare_mesh_for_parametrization(\n    mesh_uri: str, scale: tuple[float, float, float] = (1.0, 1.0, 1.0)\n) -> dict:\n    \"\"\"\n    Load and prepare a mesh for parametric scaling with exact inertia computation.\n\n    This function loads a mesh, ensures it's watertight (crucial for volume/inertia\n    calculation), centers it, and returns the data needed for parametric scaling.\n\n    Args:\n        mesh_uri: URI/path to the mesh file.\n        scale: Initial scale factors to apply (from SDF/URDF).\n\n    Returns:\n        A dictionary containing:\n            - 'vertices': Centered mesh vertices as numpy array (Nx3)\n            - 'faces': Triangle faces as numpy array (Mx3 integer indices)\n            - 'offset': Original mesh centroid offset as numpy array (3,)\n            - 'uri': The mesh URI for reference\n            - 'is_watertight': Boolean indicating if mesh is watertight\n            - 'volume': The volume of the mesh (after scaling)\n    \"\"\"\n\n    # Load mesh\n    file = pathlib.Path(resolve_local_uri(uri=mesh_uri))\n    file_type = file.suffix.replace(\".\", \"\")\n    mesh = trimesh.load_mesh(file, file_type=file_type)\n\n    if mesh.is_empty:\n        raise RuntimeError(f\"Failed to process '{file}' with trimesh\")\n\n    # Apply initial scale from SDF/URDF\n    
mesh.apply_scale(scale)\n\n    # Check and fix watertightness\n    is_watertight = mesh.is_watertight\n    if not is_watertight:\n        logging.warning(\n            f\"Mesh {mesh_uri} is not watertight. Computing convex hull for valid inertia.\"\n        )\n        mesh = mesh.convex_hull\n        is_watertight = True\n\n    # Store original centroid as offset\n    offset = mesh.centroid.copy()\n\n    # Center the mesh\n    mesh.vertices -= offset\n\n    return {\n        \"vertices\": np.array(mesh.vertices, dtype=np.float64),\n        \"faces\": np.array(mesh.faces, dtype=np.int32),\n        \"offset\": np.array(offset, dtype=np.float64),\n        \"uri\": mesh_uri,\n        \"is_watertight\": is_watertight,\n        \"volume\": mesh.volume,\n    }\n"
  },
  {
    "path": "src/jaxsim/rbda/__init__.py",
    "content": "from . import actuation, contacts\nfrom .aba import aba\nfrom .aba_parallel import aba_parallel\nfrom .collidable_points import collidable_points_pos_vel\nfrom .crba import crba\nfrom .forward_kinematics import forward_kinematics_model\nfrom .forward_kinematics_parallel import forward_kinematics_model_parallel\nfrom .jacobian import (\n    jacobian,\n    jacobian_derivative_full_doubly_left,\n    jacobian_full_doubly_left,\n)\nfrom .kinematic_constraints import compute_constraint_wrenches\nfrom .mass_inverse import mass_inverse\nfrom .rnea import rnea\n"
  },
  {
    "path": "src/jaxsim/rbda/aba.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport jaxlie\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.math import STANDARD_GRAVITY, Adjoint, Cross\n\nfrom . import utils\n\n\ndef aba(\n    model: js.model.JaxSimModel,\n    *,\n    base_position: jtp.VectorLike,\n    base_quaternion: jtp.VectorLike,\n    joint_positions: jtp.VectorLike,\n    base_linear_velocity: jtp.VectorLike,\n    base_angular_velocity: jtp.VectorLike,\n    joint_velocities: jtp.VectorLike,\n    joint_transforms: jtp.MatrixLike,\n    joint_forces: jtp.VectorLike | None = None,\n    link_forces: jtp.MatrixLike | None = None,\n    standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,\n) -> tuple[jtp.Vector, jtp.Vector]:\n    \"\"\"\n    Compute forward dynamics using the Articulated Body Algorithm (ABA).\n\n    Args:\n        model: The model to consider.\n        base_position: The position of the base link.\n        base_quaternion: The quaternion of the base link.\n        joint_positions: The positions of the joints.\n        base_linear_velocity:\n            The linear velocity of the base link in inertial-fixed representation.\n        base_angular_velocity:\n            The angular velocity of the base link in inertial-fixed representation.\n        joint_velocities: The velocities of the joints.\n        joint_transforms: The parent-to-child transforms of the joints.\n        joint_forces: The forces applied to the joints.\n        link_forces:\n            The forces applied to the links expressed in the world frame.\n        standard_gravity: The standard gravity constant.\n\n    Returns:\n        A tuple containing the base acceleration in inertial-fixed representation\n        and the joint accelerations that result from the applications of the given\n        joint and link forces.\n\n    Note:\n        The algorithm expects a quaternion with unit norm.\n    \"\"\"\n\n    W_p_B, W_Q_B, s, W_v_WB, ṡ, _, _, τ, W_f, W_g = utils.process_inputs(\n        
model=model,\n        base_position=base_position,\n        base_quaternion=base_quaternion,\n        joint_positions=joint_positions,\n        base_linear_velocity=base_linear_velocity,\n        base_angular_velocity=base_angular_velocity,\n        joint_velocities=joint_velocities,\n        base_linear_acceleration=None,\n        base_angular_acceleration=None,\n        joint_accelerations=None,\n        joint_forces=joint_forces,\n        link_forces=link_forces,\n        standard_gravity=standard_gravity,\n    )\n\n    W_g = jnp.atleast_2d(W_g).T\n    W_v_WB = jnp.atleast_2d(W_v_WB).T\n\n    # Get the 6D spatial inertia matrices of all links.\n    M = js.model.link_spatial_inertia_matrices(model=model)\n\n    # Get the parent array λ(i).\n    # Note: λ(0) must not be used, it's initialized to -1.\n    λ = model.kin_dyn_parameters.parent_array\n\n    # Compute the base transform.\n    W_H_B = jaxlie.SE3.from_rotation_and_translation(\n        rotation=jaxlie.SO3(wxyz=W_Q_B),\n        translation=W_p_B,\n    )\n\n    # Compute 6D transforms of the base velocity.\n    W_X_B = W_H_B.adjoint()\n    B_X_W = W_H_B.inverse().adjoint()\n\n    # Extract the parent-to-child adjoints of the joints.\n    i_X_λi = jnp.asarray(joint_transforms)\n\n    # Extract the joint motion subspaces.\n    S = model.kin_dyn_parameters.motion_subspaces\n\n    # Allocate buffers.\n    v = jnp.zeros(shape=(model.number_of_links(), 6, 1))\n    c = jnp.zeros(shape=(model.number_of_links(), 6, 1))\n    pA = jnp.zeros(shape=(model.number_of_links(), 6, 1))\n    MA = jnp.zeros(shape=(model.number_of_links(), 6, 6))\n\n    # Allocate the buffer of transforms link -> base.\n    i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))\n    i_X_0 = i_X_0.at[0].set(jnp.eye(6))\n\n    # Initialize base quantities.\n    if model.floating_base():\n\n        # Base velocity v₀ in body-fixed representation.\n        v_0 = B_X_W @ W_v_WB\n        v = v.at[0].set(v_0)\n\n        # Initialize the 
articulated-body inertia (Mᴬ) of base link.\n        MA_0 = M[0]\n        MA = MA.at[0].set(MA_0)\n\n        # Initialize the articulated-body bias force (pᴬ) of the base link.\n        pA_0 = Cross.vx_star(v[0]) @ MA[0] @ v[0] - W_X_B.T @ jnp.vstack(W_f[0])\n        pA = pA.at[0].set(pA_0)\n\n    # ======\n    # Pass 1\n    # ======\n\n    Pass1Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]\n    pass_1_carry: Pass1Carry = (v, c, MA, pA, i_X_0)\n\n    # Propagate kinematics and initialize AB inertia and AB bias forces.\n    def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]:\n\n        ii = i - 1\n        v, c, MA, pA, i_X_0 = carry\n\n        # Project the joint velocity into its motion subspace.\n        vJ = S[i] * ṡ[ii]\n\n        # Propagate the link velocity.\n        v_i = i_X_λi[i] @ v[λ[i]] + vJ\n        v = v.at[i].set(v_i)\n\n        c_i = Cross.vx(v[i]) @ vJ\n        c = c.at[i].set(c_i)\n\n        # Initialize the articulated-body inertia.\n        MA_i = jnp.array(M[i])\n        MA = MA.at[i].set(MA_i)\n\n        # Compute the link-to-base transform.\n        i_Xi_0 = i_X_λi[i] @ i_X_0[λ[i]]\n        i_X_0 = i_X_0.at[i].set(i_Xi_0)\n\n        # Compute link-to-world transform for the 6D force.\n        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T\n\n        # Initialize articulated-body bias force.\n        pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(W_f[i])\n        pA = pA.at[i].set(pA_i)\n\n        return (v, c, MA, pA, i_X_0), None\n\n    (v, c, MA, pA, i_X_0), _ = (\n        jax.lax.scan(\n            f=loop_body_pass1,\n            init=pass_1_carry,\n            xs=jnp.arange(start=1, stop=model.number_of_links()),\n        )\n        if model.number_of_links() > 1\n        else [(v, c, MA, pA, i_X_0), None]\n    )\n\n    # ======\n    # Pass 2\n    # ======\n\n    U = jnp.zeros_like(S)\n    d = jnp.zeros(shape=(model.number_of_links(), 1))\n    u = 
jnp.zeros(shape=(model.number_of_links(), 1))\n\n    Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]\n    pass_2_carry: Pass2Carry = (U, d, u, MA, pA)\n\n    def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:\n\n        ii = i - 1\n        U, d, u, MA, pA = carry\n\n        U_i = MA[i] @ S[i]\n        U = U.at[i].set(U_i)\n\n        d_i = S[i].T @ U[i]\n        d = d.at[i].set(d_i.squeeze())\n\n        u_i = τ[ii] - S[i].T @ pA[i]\n        u = u.at[i].set(u_i.squeeze())\n\n        # Compute the articulated-body inertia and bias force of this link.\n        Ma = MA[i] - U[i] / d[i] @ U[i].T\n        pa = pA[i] + Ma @ c[i] + U[i] * (u[i] / d[i])\n\n        # Propagate them to the parent, handling the base link.\n        def propagate(\n            MA_pA: tuple[jtp.Matrix, jtp.Matrix],\n        ) -> tuple[jtp.Matrix, jtp.Matrix]:\n\n            MA, pA = MA_pA\n\n            MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]\n            MA = MA.at[λ[i]].set(MA_λi)\n\n            pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa\n            pA = pA.at[λ[i]].set(pA_λi)\n\n            return MA, pA\n\n        MA, pA = jax.lax.cond(\n            pred=jnp.logical_or(λ[i] != 0, model.floating_base()),\n            true_fun=propagate,\n            false_fun=lambda MA_pA: MA_pA,\n            operand=(MA, pA),\n        )\n\n        return (U, d, u, MA, pA), None\n\n    (U, d, u, MA, pA), _ = (\n        jax.lax.scan(\n            f=loop_body_pass2,\n            init=pass_2_carry,\n            xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),\n        )\n        if model.number_of_links() > 1\n        else [(U, d, u, MA, pA), None]\n    )\n\n    # ======\n    # Pass 3\n    # ======\n\n    if model.floating_base():\n        a0 = jnp.linalg.solve(-MA[0], pA[0])\n    else:\n        a0 = -B_X_W @ W_g\n\n    s̈ = jnp.zeros_like(s)\n    a = jnp.zeros_like(v).at[0].set(a0)\n\n    Pass3Carry = tuple[jtp.Matrix, jtp.Vector]\n    
pass_3_carry = (a, s̈)\n\n    def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]:\n\n        ii = i - 1\n        a, s̈ = carry\n\n        # Propagate the link acceleration.\n        a_i = i_X_λi[i] @ a[λ[i]] + c[i]\n\n        # Compute the joint acceleration.\n        s̈_ii = (u[i] - U[i].T @ a_i) / d[i]\n        s̈ = s̈.at[ii].set(s̈_ii.squeeze())\n\n        # Sum the joint acceleration to the parent link acceleration.\n        a_i = a_i + S[i] * s̈[ii]\n        a = a.at[i].set(a_i)\n\n        return (a, s̈), None\n\n    (a, s̈), _ = (\n        jax.lax.scan(\n            f=loop_body_pass3,\n            init=pass_3_carry,\n            xs=jnp.arange(1, model.number_of_links()),\n        )\n        if model.number_of_links() > 1\n        else [(a, s̈), None]\n    )\n\n    # ==============\n    # Adjust outputs\n    # ==============\n\n    # TODO: remove vstack and shape=(6, 1)?\n    if model.floating_base():\n        # Convert the base acceleration to inertial-fixed representation,\n        # and add gravity.\n        B_a_WB = a[0]\n        W_a_WB = W_X_B @ B_a_WB + W_g\n    else:\n        W_a_WB = jnp.zeros(6)\n\n    return W_a_WB.squeeze(), jnp.atleast_1d(s̈.squeeze())\n"
  },
  {
    "path": "src/jaxsim/rbda/aba_parallel.py",
    "content": "import math\n\nimport jax\nimport jax.numpy as jnp\nimport jaxlie\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.math import STANDARD_GRAVITY, Adjoint, Cross\n\nfrom . import utils\n\n\ndef aba_parallel(\n    model: js.model.JaxSimModel,\n    *,\n    base_position: jtp.VectorLike,\n    base_quaternion: jtp.VectorLike,\n    joint_positions: jtp.VectorLike,\n    base_linear_velocity: jtp.VectorLike,\n    base_angular_velocity: jtp.VectorLike,\n    joint_velocities: jtp.VectorLike,\n    joint_transforms: jtp.MatrixLike,\n    joint_forces: jtp.VectorLike | None = None,\n    link_forces: jtp.MatrixLike | None = None,\n    standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,\n) -> tuple[jtp.Vector, jtp.Vector]:\n    \"\"\"\n    Compute forward dynamics using a hybrid parallel ABA.\n\n    Passes 1 and 3 use pointer jumping in O(log D) parallel steps.\n    Pass 2 uses level-parallel processing in O(D) steps because the backward\n    inertia accumulation is not associative.\n\n    The interface and semantics are identical to :func:`aba`,\n    but passes 1 and 3 are parallelized via pointer jumping.\n    \"\"\"\n\n    W_p_B, W_Q_B, _, W_v_WB, ṡ, _, _, τ, W_f, W_g = utils.process_inputs(\n        model=model,\n        base_position=base_position,\n        base_quaternion=base_quaternion,\n        joint_positions=joint_positions,\n        base_linear_velocity=base_linear_velocity,\n        base_angular_velocity=base_angular_velocity,\n        joint_velocities=joint_velocities,\n        base_linear_acceleration=None,\n        base_angular_acceleration=None,\n        joint_accelerations=None,\n        joint_forces=joint_forces,\n        link_forces=link_forces,\n        standard_gravity=standard_gravity,\n    )\n\n    W_g = jnp.atleast_2d(W_g).T\n    W_v_WB = jnp.atleast_2d(W_v_WB).T\n\n    # Get the 6D spatial inertia matrices of all links.\n    M = js.model.link_spatial_inertia_matrices(model=model)\n\n    # Get the parent array λ(i).\n    
λ = model.kin_dyn_parameters.parent_array\n\n    # Get the tree level structure for level-parallel processing.\n    level_nodes = jnp.asarray(model.kin_dyn_parameters.level_nodes)\n    level_mask = jnp.asarray(model.kin_dyn_parameters.level_mask)\n    n_levels = level_nodes.shape[0]\n\n    # Compute the base transform.\n    W_H_B = jaxlie.SE3.from_rotation_and_translation(\n        rotation=jaxlie.SO3(wxyz=W_Q_B),\n        translation=W_p_B,\n    )\n\n    # Compute 6D transforms of the base velocity.\n    W_X_B = W_H_B.adjoint()\n    B_X_W = W_H_B.inverse().adjoint()\n\n    # Extract the parent-to-child adjoints of the joints.\n    i_X_λi = jnp.asarray(joint_transforms)\n\n    # Extract the joint motion subspaces.\n    S = model.kin_dyn_parameters.motion_subspaces\n\n    n = model.number_of_links()\n\n    # Parent array with root self-loop.\n    # Note: λ(0) is set to 0 to enable root self-referencing.\n    ptr0 = jnp.asarray(λ).at[0].set(0)\n\n    # Number of pointer-jumping rounds.\n    n_rounds = max(1, math.ceil(math.log2(max(n_levels, 2))))\n\n    # ======\n    # Pass 1\n    # ======\n\n    # Two coupled affine recurrences propagated via pointer jumping:\n    #   v_i = i_X_λi[i] @ v_parent + vJ_i\n    #   T_i = i_X_λi[i] @ T_parent\n    #\n    # Associative operator on (A, b, T):\n    #   compose(parent, child) = (A @ A_p, A @ b_p + b, A @ T_p)\n\n    # Local transforms and joint velocities.\n    ṡ_col = jnp.atleast_1d(ṡ).reshape(-1, 1)  # (n_joints, 1)\n    ṡ_padded = jnp.concatenate([jnp.zeros((1, 1)), ṡ_col])  # (n, 1)\n    vJ = S * ṡ_padded[:, :, None]  # (n, 6, 1)\n\n    # Initialize pointer-jumping state for each node.\n    A = i_X_λi.copy()  # (n, 6, 6)\n    b = vJ.copy()  # (n, 6, 1)\n    T = i_X_λi.copy()  # (n, 6, 6)\n\n    # Root initial values.\n    if model.floating_base():\n        v_0 = B_X_W @ W_v_WB\n        A = A.at[0].set(jnp.eye(6))\n        b = b.at[0].set(v_0)\n        T = T.at[0].set(jnp.eye(6))\n    else:\n        A = 
A.at[0].set(jnp.eye(6))\n        b = b.at[0].set(jnp.zeros((6, 1)))\n        T = T.at[0].set(jnp.eye(6))\n\n    ptr = ptr0.copy()\n    done = jnp.arange(n) == 0\n\n    def _pass1_jump(carry, _):\n        A, b, T, ptr, done = carry\n        need = ~done\n\n        A_par = A[ptr]\n        b_par = b[ptr]\n        T_par = T[ptr]\n\n        # Associative compose.\n        A_new = jnp.where(need[:, None, None], A @ A_par, A)\n        b_new = jnp.where(need[:, None, None], A @ b_par + b, b)\n        T_new = jnp.where(need[:, None, None], A @ T_par, T)\n\n        ptr_new = jnp.where(need, ptr[ptr], ptr)\n        done_new = done | done[ptr]\n\n        return (A_new, b_new, T_new, ptr_new, done_new), None\n\n    (_, v, i_X_0, _, _), _ = (\n        jax.lax.scan(\n            f=_pass1_jump,\n            init=(A, b, T, ptr, done),\n            xs=jnp.arange(n_rounds),\n        )\n        if n > 1\n        else ((A, b, T, ptr, done), None)\n    )\n\n    # v now contains the 6D body velocity of every link.\n    # i_X_0 contains the body-to-base transform for every link.\n\n    # Compute c, MA, pA for all nodes in parallel.\n    def _init_node(node_i):\n        vJ_i = S[node_i] * ṡ_padded[node_i]\n        c_i = Cross.vx(v[node_i]) @ vJ_i\n        MA_i = M[node_i]\n        i_Xf_W = Adjoint.inverse(i_X_0[node_i] @ B_X_W).T\n        pA_i = Cross.vx_star(v[node_i]) @ M[node_i] @ v[node_i] - i_Xf_W @ jnp.vstack(\n            W_f[node_i]\n        )\n        return c_i, MA_i, pA_i\n\n    c, MA, pA = jax.vmap(_init_node)(jnp.arange(n))\n\n    # Override base MA and pA if floating base.\n    if model.floating_base():\n        MA = MA.at[0].set(M[0])\n        pA_0 = Cross.vx_star(v[0]) @ M[0] @ v[0] - W_X_B.T @ jnp.vstack(W_f[0])\n        pA = pA.at[0].set(pA_0)\n\n    # ======\n    # Pass 2\n    # ======\n\n    # The Schur complement and multi-child scatter-add make this pass\n    # non-associative, so it remains level-parallel.\n\n    U = jnp.zeros_like(S)\n    d = jnp.ones(shape=(n, 1))  
# Ones to avoid NaN for the base node.\n    u = jnp.zeros(shape=(n, 1))\n\n    def _masked_scatter_add(arr, indices, values, m):\n        \"\"\"Add values[j] to arr[indices[j]] only where m[j] is True.\"\"\"\n        mask = jnp.reshape(m, m.shape + (1,) * (values.ndim - 1))\n        masked_values = jnp.where(mask, values, jnp.zeros_like(values))\n        return arr.at[indices].add(masked_values)\n\n    def _pass2_level(carry, level_idx):\n        U, d, u, MA, pA = carry\n        actual_level = n_levels - 1 - level_idx\n        nodes = level_nodes[actual_level]\n        mask = level_mask[actual_level]\n\n        def _process_node_pass2(node_i):\n            # Clamp index to avoid out-of-bounds for padded entries.\n            ii = jnp.maximum(node_i, 1) - 1\n            parent = λ[node_i]\n\n            U_i = MA[node_i] @ S[node_i]\n            d_i = (S[node_i].T @ U_i).squeeze()\n            u_i = (τ[ii] - S[node_i].T @ pA[node_i]).squeeze()\n\n            Ma_i = MA[node_i] - U_i / d_i @ U_i.T\n            pa_i = pA[node_i] + Ma_i @ c[node_i] + U_i * (u_i / d_i)\n\n            Ma_parent = i_X_λi[node_i].T @ Ma_i @ i_X_λi[node_i]\n            pa_parent = i_X_λi[node_i].T @ pa_i\n\n            return U_i, d_i, u_i, Ma_parent, pa_parent, parent\n\n        U_lev, d_lev, u_lev, Ma_par, pa_par, parents = jax.vmap(_process_node_pass2)(\n            nodes\n        )\n\n        mask_6x1 = mask[:, None, None]\n        mask_1 = mask[:, None]\n\n        U = carry[0].at[nodes].set(jnp.where(mask_6x1, U_lev, carry[0][nodes]))\n        d = carry[1].at[nodes].set(jnp.where(mask_1, d_lev[:, None], carry[1][nodes]))\n        u = carry[2].at[nodes].set(jnp.where(mask_1, u_lev[:, None], carry[2][nodes]))\n\n        should_propagate = jnp.where(\n            model.floating_base(),\n            mask,\n            jnp.logical_and(mask, parents != 0),\n        )\n\n        MA = _masked_scatter_add(carry[3], parents, Ma_par, should_propagate)\n        pA = _masked_scatter_add(carry[4], 
parents, pa_par, should_propagate)\n\n        return (U, d, u, MA, pA), None\n\n    n_backward_levels = n_levels - 1\n    (U, d, u, MA, pA), _ = (\n        jax.lax.scan(\n            f=_pass2_level,\n            init=(U, d, u, MA, pA),\n            xs=jnp.arange(n_backward_levels),\n        )\n        if n_backward_levels > 0\n        else ((U, d, u, MA, pA), None)\n    )\n\n    # ======\n    # Pass 3\n    # ======\n\n    # The acceleration recurrence is an affine recurrence:\n    #   a_i = P_i @ i_X_λi[i] @ a_parent + P_i @ c_i + S_i * u_i / d_i\n    # where P_i = I - S_i @ U_i^T / d_i is the 6x6 projection matrix.\n\n    if model.floating_base():\n        a0 = jnp.linalg.solve(-MA[0], pA[0])\n    else:\n        a0 = -B_X_W @ W_g\n\n    # Pre-compute the affine recurrence coefficients for all nodes.\n    def _init_pass3(node_i):\n        P_i = jnp.eye(6) - S[node_i] @ U[node_i].T / d[node_i]\n        A_i = P_i @ i_X_λi[node_i]\n        b_i = P_i @ c[node_i] + S[node_i] * (u[node_i] / d[node_i])\n        return A_i, b_i\n\n    A, b = jax.vmap(_init_pass3)(jnp.arange(n))\n\n    # Root acceleration is known.\n    A = A.at[0].set(jnp.eye(6))\n    b = b.at[0].set(a0)\n\n    # Pointer jumping for the affine recurrence.\n    ptr = ptr0.copy()\n    done = jnp.arange(n) == 0\n\n    def _pass3_jump(carry, _):\n        A, b, ptr, done = carry\n        need = ~done\n\n        A_par = A[ptr]\n        b_par = b[ptr]\n\n        # Associative compose.\n        A_new = jnp.where(need[:, None, None], A @ A_par, A)\n        b_new = jnp.where(need[:, None, None], A @ b_par + b, b)\n\n        ptr_new = jnp.where(need, ptr[ptr], ptr)\n        done_new = done | done[ptr]\n\n        return (A_new, b_new, ptr_new, done_new), None\n\n    (_, a, _, _), _ = (\n        jax.lax.scan(\n            f=_pass3_jump,\n            init=(A, b, ptr, done),\n            xs=jnp.arange(n_rounds),\n        )\n        if n > 1\n        else ((A, b, ptr, done), None)\n    )\n\n    # Recover joint 
accelerations: s̈_i = (u_i - U_i^T @ a_before_i) / d_i\n    # where a_before_i = i_X_λi[i] @ a_parent + c_i.\n    a_λi = a[ptr0]\n    a_before = i_X_λi @ a_λi + c\n    Ut_a = (U.transpose(0, 2, 1) @ a_before).squeeze(-1)  # (n, 1)\n    s̈ = (u - Ut_a) / d  # (n, 1)\n\n    # ==============\n    # Adjust outputs\n    # ==============\n\n    if model.floating_base():\n        B_a_WB = a[0]\n        W_a_WB = W_X_B @ B_a_WB + W_g\n    else:\n        W_a_WB = jnp.zeros(6)\n\n    # Joint accelerations: skip base index, take indices 1..n-1.\n    s̈_out = s̈[1:]\n\n    return W_a_WB.squeeze(), jnp.atleast_1d(s̈_out.squeeze())\n"
  },
  {
    "path": "src/jaxsim/rbda/actuation/__init__.py",
    "content": "from .common import ActuationParams\n"
  },
  {
    "path": "src/jaxsim/rbda/actuation/common.py",
    "content": "import dataclasses\n\nimport jax_dataclasses\nfrom jax_dataclasses import Static\n\nimport jaxsim.typing as jtp\nfrom jaxsim.utils import JaxsimDataclass\n\n\n@jax_dataclasses.pytree_dataclass\nclass ActuationParams(JaxsimDataclass):\n    \"\"\"\n    Parameters class for the actuation model.\n    \"\"\"\n\n    torque_max: jtp.Float = dataclasses.field(default=3000.0)  # (Nm)\n    omega_th: jtp.Float = dataclasses.field(default=30.0)  # (rad/s)\n    omega_max: jtp.Float = dataclasses.field(default=100.0)  # (rad/s)\n    enable_friction: Static[bool] = dataclasses.field(default=True)\n"
  },
  {
    "path": "src/jaxsim/rbda/collidable_points.py",
    "content": "import jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.math import Skew\n\n\ndef collidable_points_pos_vel(\n    model: js.model.JaxSimModel,\n    *,\n    link_transforms: jtp.Matrix,\n    link_velocities: jtp.Matrix,\n) -> tuple[jtp.Matrix, jtp.Matrix]:\n    \"\"\"\n\n    Compute the position and linear velocity of the enabled collidable points in the world frame.\n\n    Args:\n        model: The model to consider.\n        link_transforms: The transforms from the world frame to each link.\n        link_velocities: The linear and angular velocities of each link.\n\n    Returns:\n        A tuple containing the position and linear velocity of the enabled collidable points.\n    \"\"\"\n\n    # Get the indices of the enabled collidable points.\n    indices_of_enabled_collidable_points = (\n        model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points\n    )\n\n    parent_link_idx_of_enabled_collidable_points = jnp.array(\n        model.kin_dyn_parameters.contact_parameters.body, dtype=int\n    )[indices_of_enabled_collidable_points]\n\n    L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[\n        indices_of_enabled_collidable_points\n    ]\n\n    if len(indices_of_enabled_collidable_points) == 0:\n        return jnp.array(0).astype(float), jnp.empty(0).astype(float)\n\n    def process_point_kinematics(\n        Li_p_C: jtp.Vector, parent_body: jtp.Int\n    ) -> tuple[jtp.Vector, jtp.Vector]:\n\n        # Compute the position of the collidable point.\n        W_p_Ci = (link_transforms[parent_body] @ jnp.hstack([Li_p_C, 1]))[0:3]\n\n        # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}.\n        CW_vl_WCi = (\n            jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()])\n            @ link_velocities[parent_body].squeeze()\n        )\n\n        return W_p_Ci, CW_vl_WCi\n\n    # Process all the collidable points in parallel.\n    W_p_Ci, 
CW_vl_WC = jax.vmap(process_point_kinematics)(\n        L_p_Ci,\n        parent_link_idx_of_enabled_collidable_points,\n    )\n\n    return W_p_Ci, CW_vl_WC\n"
  },
  {
    "path": "src/jaxsim/rbda/contacts/__init__.py",
    "content": "from . import relaxed_rigid, rigid, soft\nfrom .common import ContactModel, ContactsParams\nfrom .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams\nfrom .rigid import RigidContacts, RigidContactsParams\nfrom .soft import SoftContacts, SoftContactsParams\n\nContactParamsTypes = (\n    SoftContactsParams | RigidContactsParams | RelaxedRigidContactsParams\n)\n"
  },
  {
    "path": "src/jaxsim/rbda/contacts/common.py",
    "content": "from __future__ import annotations\n\nimport abc\nimport functools\n\nimport jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.terrain\nimport jaxsim.typing as jtp\nfrom jaxsim.math import STANDARD_GRAVITY\nfrom jaxsim.utils import JaxsimDataclass\n\ntry:\n    from typing import Self\nexcept ImportError:\n    from typing_extensions import Self\n\n\nMAX_STIFFNESS = 1e6\nMAX_DAMPING = 1e4\n\n\n@functools.partial(jax.jit, static_argnames=(\"terrain\",))\ndef compute_penetration_data(\n    p: jtp.VectorLike,\n    v: jtp.VectorLike,\n    terrain: jaxsim.terrain.Terrain,\n) -> tuple[jtp.Float, jtp.Float, jtp.Vector]:\n    \"\"\"\n    Compute the penetration data (depth, rate, and terrain normal) of a collidable point.\n\n    Args:\n        p: The position of the collidable point.\n        v:\n            The linear velocity of the point (linear component of the mixed 6D velocity\n            of the implicit frame `C = (W_p_C, [W])` associated to the point).\n        terrain: The considered terrain.\n\n    Returns:\n        A tuple containing the penetration depth, the penetration velocity,\n        and the considered terrain normal.\n    \"\"\"\n\n    # Pre-process the position and the linear velocity of the collidable point.\n    W_ṗ_C = jnp.array(v).squeeze()\n    px, py, pz = jnp.array(p).squeeze()\n\n    # Compute the terrain normal and the contact depth.\n    n̂ = terrain.normal(x=px, y=py).squeeze()\n    h = jnp.array([0, 0, terrain.height(x=px, y=py) - pz])\n\n    # Compute the penetration depth normal to the terrain.\n    δ = jnp.maximum(0.0, jnp.dot(h, n̂))\n\n    # Compute the penetration normal velocity.\n    δ_dot = -jnp.dot(W_ṗ_C, n̂)\n\n    # Enforce the penetration rate to be zero when the penetration depth is zero.\n    δ_dot = jnp.where(δ > 0, δ_dot, 0.0)\n\n    return δ, δ_dot, n̂\n\n\nclass ContactsParams(JaxsimDataclass):\n    \"\"\"\n    Abstract class representing the parameters of a contact model.\n\n    
Note:\n        This class is supposed to store only the tunable parameters of the contact\n        model, i.e. all those parameters that can be changed during runtime.\n        If the contact model has also static parameters, they should be stored\n        in the corresponding `ContactModel` class.\n    \"\"\"\n\n    @classmethod\n    @abc.abstractmethod\n    def build(cls: type[Self], **kwargs) -> Self:\n        \"\"\"\n        Create a `ContactsParams` instance with specified parameters.\n\n        Returns:\n            The `ContactsParams` instance.\n        \"\"\"\n        pass\n\n    def build_default_from_jaxsim_model(\n        self: type[Self],\n        model: js.model.JaxSimModel,\n        *,\n        stiffness: jtp.FloatLike | None = None,\n        damping: jtp.FloatLike | None = None,\n        standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,\n        static_friction_coefficient: jtp.FloatLike = 0.5,\n        max_penetration: jtp.FloatLike = 0.001,\n        number_of_active_collidable_points_steady_state: jtp.IntLike = 1,\n        damping_ratio: jtp.FloatLike = 1.0,\n        p: jtp.FloatLike = 0.5,\n        q: jtp.FloatLike = 0.5,\n        **kwargs,\n    ) -> Self:\n        \"\"\"\n        Create a `ContactsParams` instance with default parameters.\n\n        Args:\n            model: The robot model considered by the contact model.\n            stiffness: The stiffness of the contact model.\n            damping: The damping of the contact model.\n            standard_gravity: The standard gravity acceleration.\n            static_friction_coefficient: The static friction coefficient.\n            max_penetration: The maximum penetration depth.\n            number_of_active_collidable_points_steady_state:\n                The number of active collidable points in steady state.\n            damping_ratio: The damping ratio.\n            p: The first parameter of the contact model.\n            q: The second parameter of the contact model.\n            
**kwargs: Optional additional arguments.\n\n        Returns:\n            The `ContactsParams` instance.\n\n        Note:\n            The `stiffness` is intended as the terrain stiffness in the Soft Contacts model,\n            while it is the Baumgarte stabilization stiffness in the Rigid Contacts model.\n\n            The `damping` is intended as the terrain damping in the Soft Contacts model,\n            while it is the Baumgarte stabilization damping in the Rigid Contacts model.\n\n            The `damping_ratio` parameter allows to operate on the following conditions:\n            - ξ > 1.0: over-damped\n            - ξ = 1.0: critically damped\n            - ξ < 1.0: under-damped\n        \"\"\"\n\n        # Use symbols for input parameters.\n        ξ = damping_ratio\n        δ_max = max_penetration\n        μc = static_friction_coefficient\n        nc = number_of_active_collidable_points_steady_state\n\n        # Compute the total mass of the model.\n        m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum()\n\n        # Compute the stiffness to get the desired steady-state penetration.\n        # Note that this is dependent on the non-linear exponent used in\n        # the damping term of the Hunt/Crossley model.\n        if stiffness is None:\n            # Compute the average support force on each collidable point.\n            f_average = m * standard_gravity / nc\n\n            stiffness = f_average / jnp.power(δ_max, 1 + p)\n            stiffness = jnp.clip(stiffness, 0, MAX_STIFFNESS)\n\n        # Compute the damping using the damping ratio.\n        critical_damping = 2 * jnp.sqrt(stiffness * m)\n        if damping is None:\n            damping = ξ * critical_damping\n            damping = jnp.clip(damping, 0, MAX_DAMPING)\n\n        return self.build(\n            K=stiffness,\n            D=damping,\n            mu=μc,\n            p=p,\n            q=q,\n            **kwargs,\n        )\n\n    @abc.abstractmethod\n    def 
valid(self, **kwargs) -> jtp.BoolLike:\n        \"\"\"\n        Check if the parameters are valid.\n\n        Returns:\n            True if the parameters are valid, False otherwise.\n        \"\"\"\n        pass\n\n\nclass ContactModel(JaxsimDataclass):\n    \"\"\"\n    Abstract class representing a contact model.\n    \"\"\"\n\n    @classmethod\n    @abc.abstractmethod\n    def build(\n        cls: type[Self],\n        **kwargs,\n    ) -> Self:\n        \"\"\"\n        Create a `ContactModel` instance with specified parameters.\n\n        Returns:\n            The `ContactModel` instance.\n        \"\"\"\n\n        pass\n\n    @abc.abstractmethod\n    def compute_contact_forces(\n        self,\n        model: js.model.JaxSimModel,\n        data: js.data.JaxSimModelData,\n        **kwargs,\n    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:\n        \"\"\"\n        Compute the contact forces.\n\n        Args:\n            model: The robot model considered by the contact model.\n            data: The data of the considered model.\n            **kwargs: Optional additional arguments, specific to the contact model.\n\n        Returns:\n            A tuple containing as first element the computed 6D contact force applied to\n            the contact points and expressed in the world frame, and as second element\n            a dictionary of optional additional information.\n        \"\"\"\n\n        pass\n\n    @classmethod\n    def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:\n        \"\"\"\n        Build zero state variables of the contact model.\n\n        Args:\n            model: The robot model considered by the contact model.\n\n        Note:\n            There are contact models that require to extend the state vector of the\n            integrated ODE system with additional variables. 
Our integrators are\n            capable of operating on a generic state, as long as it is a PyTree.\n            This method builds the zero state variables of the contact model as a\n            dictionary of JAX arrays.\n\n        Returns:\n            A dictionary storing the zero state variables of the contact model.\n        \"\"\"\n\n        return {}\n\n    @property\n    def _parameters_class(self) -> type[ContactsParams]:\n        \"\"\"\n        Return the class of the contact parameters.\n\n        Returns:\n            The class of the contact parameters.\n        \"\"\"\n        import importlib\n\n        return getattr(\n            importlib.import_module(\"jaxsim.rbda.contacts\"),\n            (\n                self.__name__ + \"Params\"\n                if isinstance(self, type)\n                else self.__class__.__name__ + \"Params\"\n            ),\n        )\n\n    @abc.abstractmethod\n    def update_contact_state(\n        self: type[Self], old_contact_state: dict[str, jtp.Array]\n    ) -> dict[str, jtp.Array]:\n        \"\"\"\n        Update the contact state.\n\n        Args:\n            old_contact_state: The old contact state.\n\n        Returns:\n            The updated contact state.\n        \"\"\"\n\n    @abc.abstractmethod\n    def update_velocity_after_impact(\n        self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n    ) -> js.data.JaxSimModelData:\n        \"\"\"\n        Update the velocity after an impact.\n\n        Args:\n            model: The robot model considered by the contact model.\n            data: The data of the considered model.\n\n        Returns:\n            The updated data of the considered model.\n        \"\"\"\n"
  },
  {
    "path": "src/jaxsim/rbda/contacts/relaxed_rigid.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nfrom collections.abc import Callable\nfrom typing import Any\n\nimport jax\nimport jax.numpy as jnp\nimport jax_dataclasses\nimport optax\nfrom optax.tree_utils import tree_get\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr\n\nfrom . import common, soft\n\ntry:\n    from typing import Self\nexcept ImportError:\n    from typing_extensions import Self\n\n\ntry:\n    from optax.tree_utils import tree_norm\nexcept ImportError:\n    from optax.tree_utils import tree_l2_norm as tree_norm\n\n\n@jax_dataclasses.pytree_dataclass\nclass RelaxedRigidContactsParams(common.ContactsParams):\n    \"\"\"Parameters of the relaxed rigid contacts model.\"\"\"\n\n    # Time constant\n    time_constant: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(0.02, dtype=float)\n    )\n\n    # Adimensional damping coefficient\n    damping_coefficient: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(1.0, dtype=float)\n    )\n\n    # Minimum impedance\n    d_min: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(0.9, dtype=float)\n    )\n\n    # Maximum impedance\n    d_max: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(0.95, dtype=float)\n    )\n\n    # Width\n    width: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(0.001, dtype=float)\n    )\n\n    # Midpoint\n    midpoint: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(0.5, dtype=float)\n    )\n\n    # Power exponent\n    power: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(2.0, dtype=float)\n    )\n\n    # Stiffness\n    K: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(0.0, dtype=float)\n    )\n\n    # Damping\n    D: jtp.Float = dataclasses.field(\n        
default_factory=lambda: jnp.array(0.0, dtype=float)\n    )\n\n    # Friction coefficient\n    mu: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(0.005, dtype=float)\n    )\n\n    def __hash__(self) -> int:\n        from jaxsim.utils.wrappers import HashedNumpyArray\n\n        return hash(\n            (\n                HashedNumpyArray(self.time_constant),\n                HashedNumpyArray(self.damping_coefficient),\n                HashedNumpyArray(self.d_min),\n                HashedNumpyArray(self.d_max),\n                HashedNumpyArray(self.width),\n                HashedNumpyArray(self.midpoint),\n                HashedNumpyArray(self.power),\n                HashedNumpyArray(self.K),\n                HashedNumpyArray(self.D),\n                HashedNumpyArray(self.mu),\n            )\n        )\n\n    def __eq__(self, other: RelaxedRigidContactsParams) -> bool:\n        if not isinstance(other, RelaxedRigidContactsParams):\n            return False\n\n        return hash(self) == hash(other)\n\n    @classmethod\n    def build(\n        cls: type[Self],\n        *,\n        time_constant: jtp.FloatLike | None = None,\n        damping_coefficient: jtp.FloatLike | None = None,\n        d_min: jtp.FloatLike | None = None,\n        d_max: jtp.FloatLike | None = None,\n        width: jtp.FloatLike | None = None,\n        midpoint: jtp.FloatLike | None = None,\n        power: jtp.FloatLike | None = None,\n        K: jtp.FloatLike | None = None,\n        D: jtp.FloatLike | None = None,\n        mu: jtp.FloatLike | None = None,\n        **kwargs,\n    ) -> Self:\n        \"\"\"Create a `RelaxedRigidContactsParams` instance.\"\"\"\n\n        def default(name: str):\n            return cls.__dataclass_fields__[name].default_factory()\n\n        return cls(\n            time_constant=jnp.array(\n                (\n                    time_constant\n                    if time_constant is not None\n                    else 
default(\"time_constant\")\n                ),\n                dtype=float,\n            ),\n            damping_coefficient=jnp.array(\n                (\n                    damping_coefficient\n                    if damping_coefficient is not None\n                    else default(\"damping_coefficient\")\n                ),\n                dtype=float,\n            ),\n            d_min=jnp.array(\n                d_min if d_min is not None else default(\"d_min\"), dtype=float\n            ),\n            d_max=jnp.array(\n                d_max if d_max is not None else default(\"d_max\"), dtype=float\n            ),\n            width=jnp.array(\n                width if width is not None else default(\"width\"), dtype=float\n            ),\n            midpoint=jnp.array(\n                midpoint if midpoint is not None else default(\"midpoint\"), dtype=float\n            ),\n            power=jnp.array(\n                power if power is not None else default(\"power\"), dtype=float\n            ),\n            K=jnp.array(\n                K if K is not None else default(\"K\"),\n                dtype=float,\n            ),\n            D=jnp.array(D if D is not None else default(\"D\"), dtype=float),\n            mu=jnp.array(mu if mu is not None else default(\"mu\"), dtype=float),\n        )\n\n    def valid(self) -> jtp.BoolLike:\n        \"\"\"Check if the parameters are valid.\"\"\"\n\n        return bool(\n            jnp.all(self.time_constant >= 0.0)\n            and jnp.all(self.damping_coefficient > 0.0)\n            and jnp.all(self.d_min >= 0.0)\n            and jnp.all(self.d_max <= 1.0)\n            and jnp.all(self.d_min <= self.d_max)\n            and jnp.all(self.width >= 0.0)\n            and jnp.all(self.midpoint >= 0.0)\n            and jnp.all(self.power >= 0.0)\n            and jnp.all(self.mu >= 0.0)\n        )\n\n\n@jax_dataclasses.pytree_dataclass\nclass RelaxedRigidContacts(common.ContactModel):\n    \"\"\"Relaxed rigid 
contacts model.\"\"\"\n\n    _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(\n        default=(\"tol\", \"maxiter\", \"memory_size\", \"scale_init_precond\"), kw_only=True\n    )\n    _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(\n        default=(1e-6, 50, 10, False), kw_only=True\n    )\n\n    @property\n    def solver_options(self) -> dict[str, Any]:\n        \"\"\"Get the solver options.\"\"\"\n\n        return dict(\n            zip(\n                self._solver_options_keys,\n                self._solver_options_values,\n                strict=True,\n            )\n        )\n\n    @classmethod\n    def build(\n        cls: type[Self],\n        solver_options: dict[str, Any] | None = None,\n        **kwargs,\n    ) -> Self:\n        \"\"\"\n        Create a `RelaxedRigidContacts` instance with specified parameters.\n\n        Args:\n            solver_options: The options to pass to the L-BFGS solver.\n            **kwargs: The parameters of the relaxed rigid contacts model.\n\n        Returns:\n            The `RelaxedRigidContacts` instance.\n        \"\"\"\n\n        # Get the default solver options.\n        default_solver_options = dict(\n            zip(cls._solver_options_keys, cls._solver_options_values, strict=True)\n        )\n\n        # Create the solver options to set by combining the default solver options\n        # with the user-provided solver options.\n        solver_options = default_solver_options | (\n            solver_options if solver_options is not None else {}\n        )\n\n        # Make sure that the solver options are hashable.\n        # We need to check this because the solver options are static.\n        try:\n            hash(tuple(solver_options.values()))\n        except TypeError as exc:\n            raise ValueError(\n                \"The values of the solver options must be hashable.\"\n            ) from exc\n\n        return cls(\n            
_solver_options_keys=tuple(solver_options.keys()),\n            _solver_options_values=tuple(solver_options.values()),\n            **kwargs,\n        )\n\n    def update_contact_state(\n        self: type[Self], old_contact_state: dict[str, jtp.Array]\n    ) -> dict[str, jtp.Array]:\n        \"\"\"\n        Update the contact state.\n\n        Args:\n            old_contact_state: The old contact state.\n\n        Returns:\n            The updated contact state.\n        \"\"\"\n\n        return {}\n\n    def update_velocity_after_impact(\n        self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n    ) -> js.data.JaxSimModelData:\n        \"\"\"\n        Update the velocity after an impact.\n\n        Args:\n            model: The robot model considered by the contact model.\n            data: The data of the considered model.\n\n        Returns:\n            The updated data of the considered model.\n        \"\"\"\n\n        return data\n\n    @jax.jit\n    def compute_contact_forces(\n        self,\n        model: js.model.JaxSimModel,\n        data: js.data.JaxSimModelData,\n        *,\n        link_forces: jtp.MatrixLike | None = None,\n        joint_force_references: jtp.VectorLike | None = None,\n    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:\n        \"\"\"\n        Compute the contact forces.\n\n        Args:\n            model: The model to consider.\n            data: The data of the considered model.\n            link_forces:\n                Optional `(n_links, 6)` matrix of external forces acting on the links,\n                expressed in the same representation of data.\n            joint_force_references:\n                Optional `(n_joints,)` vector of joint forces.\n\n        Returns:\n            A tuple containing as first element the computed contact forces in inertial representation.\n        \"\"\"\n\n        link_forces = jnp.atleast_2d(\n            jnp.array(link_forces, dtype=float).squeeze()\n           
 if link_forces is not None\n            else jnp.zeros((model.number_of_links(), 6))\n        )\n\n        joint_force_references = jnp.atleast_1d(\n            jnp.array(joint_force_references, dtype=float).squeeze()\n            if joint_force_references is not None\n            else jnp.zeros(model.number_of_joints())\n        )\n\n        references = js.references.JaxSimModelReferences.build(\n            model=model,\n            data=data,\n            velocity_representation=data.velocity_representation,\n            link_forces=link_forces,\n            joint_force_references=joint_force_references,\n        )\n\n        # Compute the position and linear velocities (mixed representation) of\n        # all collidable points belonging to the robot.\n        position, velocity = js.contact.collidable_point_kinematics(\n            model=model, data=data\n        )\n\n        # Compute the penetration depth and velocity of the collidable points.\n        # Note that this function considers the penetration in the normal direction.\n        δ, _, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(\n            position, velocity, model.terrain\n        )\n\n        # Compute the position in the constraint frame.\n        position_constraint = jax.vmap(lambda δ, n̂: -δ * n̂)(δ, n̂)\n\n        # Compute the regularization terms.\n        a_ref, r, *_ = self._regularizers(\n            model=model,\n            position_constraint=position_constraint,\n            velocity_constraint=velocity,\n            parameters=model.contact_params,\n        )\n\n        # Compute the transforms of the implicit frames corresponding to the\n        # collidable points.\n        W_H_C = js.contact.transforms(model=model, data=data)\n\n        with (\n            data.switch_velocity_representation(VelRepr.Mixed),\n            references.switch_velocity_representation(VelRepr.Mixed),\n        ):\n            BW_ν = data.generalized_velocity\n\n            
BW_ν̇_free = jnp.hstack(\n                js.model.forward_dynamics_aba(\n                    model=model,\n                    data=data,\n                    link_forces=references.link_forces(model=model, data=data),\n                    joint_forces=references.joint_force_references(model=model),\n                )\n            )\n\n            M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data)\n\n            # Compute the linear part of the Jacobian of the collidable points\n            Jl_WC = jnp.vstack(\n                jax.vmap(lambda J, δ: J * (δ > 0))(\n                    js.contact.jacobian(model=model, data=data)[:, :3, :], δ\n                )\n            )\n\n            # Compute the linear part of the Jacobian derivative of the collidable points\n            J̇l_WC = jnp.vstack(\n                jax.vmap(lambda J̇, δ: J̇ * (δ > 0))(\n                    js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ\n                ),\n            )\n\n        # Compute the Delassus matrix directly using J and J̇.\n        G_contacts = Jl_WC @ M_inv @ Jl_WC.T\n\n        # Compute the free mixed linear acceleration of the collidable points.\n        CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇l_WC @ BW_ν\n\n        # Calculate quantities for the linear optimization problem.\n        R = jnp.diag(r)\n        A = G_contacts + R\n        b = CW_al_free_WC - a_ref\n\n        # Create the objective function to minimize as a lambda computing the cost\n        # from the optimized variables x.\n        objective = lambda x, A, b: jnp.sum(jnp.square(A @ x + b))\n\n        # ========================================\n        # Helper function to run the L-BFGS solver\n        # ========================================\n\n        def run_optimization(\n            init_params: jtp.Vector,\n            fun: Callable,\n            opt: optax.GradientTransformationExtraArgs,\n            maxiter: int,\n            tol: float,\n        ) -> 
tuple[jtp.Vector, optax.OptState]:\n\n            # Get the function to compute the loss and the gradient w.r.t. its inputs.\n            value_and_grad_fn = optax.value_and_grad_from_state(fun)\n\n            # Initialize the carry of the following loop.\n            OptimizationCarry = tuple[jtp.Vector, optax.OptState]\n            init_carry: OptimizationCarry = (init_params, opt.init(params=init_params))\n\n            def step(carry: OptimizationCarry) -> OptimizationCarry:\n\n                params, state = carry\n\n                value, grad = value_and_grad_fn(\n                    params,\n                    state=state,\n                    A=A,\n                    b=b,\n                )\n\n                updates, state = opt.update(\n                    updates=grad,\n                    state=state,\n                    params=params,\n                    value=value,\n                    grad=grad,\n                    value_fn=fun,\n                    A=A,\n                    b=b,\n                )\n\n                params = optax.apply_updates(params, updates)\n\n                return params, state\n\n            # TODO: maybe fix the number of iterations and switch to scan?\n            def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool:\n\n                _, state = carry\n\n                iter_num = tree_get(state, \"count\")\n                grad = tree_get(state, \"grad\")\n                err = tree_norm(grad)\n\n                return (iter_num == 0) | ((iter_num < maxiter) & (err >= tol))\n\n            final_params, final_state = jax.lax.while_loop(\n                continuing_criterion, step, init_carry\n            )\n\n            return final_params, final_state\n\n        # ======================================\n        # Compute the contact forces with L-BFGS\n        # ======================================\n\n        # Initialize the optimized forces with a linear Hunt/Crossley model.\n        init_params = 
jax.vmap(\n            lambda p, v: soft.SoftContacts.hunt_crossley_contact_model(\n                position=p,\n                velocity=v,\n                terrain=model.terrain,\n                K=1e6,\n                D=2e3,\n                p=0.5,\n                q=0.5,\n                # No tangential initial forces.\n                mu=0.0,\n                tangential_deformation=jnp.zeros(3),\n            )[0]\n        )(position, velocity).flatten()\n\n        # Get the solver options.\n        solver_options = self.solver_options\n\n        # Extract the options corresponding to the convergence criteria.\n        # All the remaining options are passed to the solver.\n        tol = solver_options.pop(\"tol\")\n        maxiter = solver_options.pop(\"maxiter\")\n\n        solve_fn = lambda *_: run_optimization(\n            init_params=init_params,\n            fun=objective,\n            opt=optax.lbfgs(**solver_options),\n            tol=tol,\n            maxiter=maxiter,\n        )\n\n        # Compute the 3D linear force in C[W] frame.\n        solution, _ = jax.lax.custom_linear_solve(\n            lambda x: A @ x,\n            -b,\n            solve=solve_fn,\n            symmetric=True,\n            has_aux=True,\n        )\n\n        # Reshape the optimized solution to be a matrix of 3D contact forces.\n        CW_fl_C = solution.reshape(-1, 3)\n\n        # Convert the contact forces from mixed to inertial-fixed representation.\n        W_f_C = jax.vmap(\n            lambda CW_fl_C, W_H_C: (\n                ModelDataWithVelocityRepresentation.other_representation_to_inertial(\n                    array=jnp.zeros(6).at[0:3].set(CW_fl_C),\n                    transform=W_H_C,\n                    other_representation=VelRepr.Mixed,\n                    is_force=True,\n                )\n            ),\n        )(CW_fl_C, W_H_C)\n\n        return W_f_C, {}\n\n    @staticmethod\n    def _regularizers(\n        model: js.model.JaxSimModel,\n        
position_constraint: jtp.Vector,\n        velocity_constraint: jtp.Vector,\n        parameters: RelaxedRigidContactsParams,\n    ) -> tuple:\n        \"\"\"\n        Compute the contact jacobian and the reference acceleration.\n\n        Args:\n            model: The jaxsim model.\n            position_constraint: The position of the collidable points in the constraint frame.\n            velocity_constraint: The velocity of the collidable points in the constraint frame.\n            parameters: The parameters of the relaxed rigid contacts model.\n\n        Returns:\n            A tuple containing the reference acceleration, the regularization matrix,\n            the stiffness, and the damping.\n        \"\"\"\n\n        # Extract the parameters of the contact model.\n        Ω, ζ, ξ_min, ξ_max, width, mid, p, K, D, μ = (\n            getattr(parameters, field)\n            for field in (\n                \"time_constant\",\n                \"damping_coefficient\",\n                \"d_min\",\n                \"d_max\",\n                \"width\",\n                \"midpoint\",\n                \"power\",\n                \"K\",\n                \"D\",\n                \"mu\",\n            )\n        )\n\n        # Get the indices of the enabled collidable points.\n        indices_of_enabled_collidable_points = (\n            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points\n        )\n\n        parent_link_idx_of_enabled_collidable_points = jnp.array(\n            model.kin_dyn_parameters.contact_parameters.body, dtype=int\n        )[indices_of_enabled_collidable_points]\n\n        # Compute the 6D inertia matrices of all links.\n        M_L = js.model.link_spatial_inertia_matrices(model=model)\n\n        def imp_aref(\n            pos: jtp.Vector,\n            vel: jtp.Vector,\n        ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector]:\n            \"\"\"\n            Calculate impedance and offset acceleration in 
constraint frame.\n\n            Args:\n                pos: position in constraint frame.\n                vel: velocity in constraint frame.\n\n            Returns:\n                ξ: computed impedance\n                a_ref: offset acceleration in constraint frame\n                K: computed stiffness\n                D: computed damping\n            \"\"\"\n\n            imp_x = jnp.abs(pos) / width\n\n            imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)\n            imp_b = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - imp_x, p)\n            imp_y = jnp.where(imp_x < mid, imp_a, imp_b)\n\n            # Compute the impedance.\n            ξ = ξ_min + imp_y * (ξ_max - ξ_min)\n            ξ = jnp.clip(ξ, ξ_min, ξ_max)\n            ξ = jnp.where(imp_x > 1.0, ξ_max, ξ)\n\n            # Compute the spring and damper parameters during runtime from the\n            # impedance and other contact parameters.\n            K = 1 / (ξ_max * Ω * ζ) ** 2\n            D = 2 / (ξ_max * Ω)\n\n            # If the user specifies K and D and they are negative, the computed `a_ref`\n            # becomes something more similar to a classic Baumgarte regularization.\n            K = jnp.where(K < 0, -K / ξ_max**2, K)\n            D = jnp.where(D < 0, -D / ξ_max, D)\n\n            # Compute the reference acceleration.\n            a_ref = -(D * vel + K * ξ * pos)\n\n            return ξ, a_ref, K, D\n\n        def compute_row(\n            *,\n            link_idx: jtp.Int,\n            pos: jtp.Vector,\n            vel: jtp.Vector,\n        ) -> tuple[jtp.Vector, jtp.Matrix, jtp.Vector, jtp.Vector]:\n\n            # Compute the reference acceleration.\n            ξ, a_ref, K, D = imp_aref(pos=pos, vel=vel)\n\n            # Compute the regularization term.\n            R = (\n                (2 * μ**2 * (1 - ξ) / (ξ + 1e-12))\n                * (1 + μ**2)\n                @ jnp.linalg.inv(M_L[link_idx, :3, :3])\n            )\n\n            # Return 
the computed values, setting them to zero in case of no contact.\n            is_active = (pos.dot(pos) > 0).astype(float)\n            return jax.tree.map(\n                lambda x: jnp.atleast_1d(x) * is_active, (a_ref, R, K, D)\n            )\n\n        a_ref, R, K, D = jax.tree.map(\n            f=jnp.concatenate,\n            tree=(\n                *jax.vmap(compute_row)(\n                    link_idx=parent_link_idx_of_enabled_collidable_points,\n                    pos=position_constraint,\n                    vel=velocity_constraint,\n                ),\n            ),\n        )\n\n        return a_ref, R, K, D\n"
  },
  {
    "path": "src/jaxsim/rbda/contacts/rigid.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nfrom typing import Any\n\nimport jax\nimport jax.numpy as jnp\nimport jax_dataclasses\nimport qpax\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim import logging\nfrom jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr\n\nfrom . import common\nfrom .common import ContactModel, ContactsParams\n\ntry:\n    from typing import Self\nexcept ImportError:\n    from typing_extensions import Self\n\n\n@jax_dataclasses.pytree_dataclass\nclass RigidContactsParams(ContactsParams):\n    \"\"\"Parameters of the rigid contacts model.\"\"\"\n\n    # Static friction coefficient\n    mu: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(0.5, dtype=float)\n    )\n\n    # Baumgarte proportional term\n    K: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(0.0, dtype=float)\n    )\n\n    # Baumgarte derivative term\n    D: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(0.0, dtype=float)\n    )\n\n    def __hash__(self) -> int:\n        from jaxsim.utils.wrappers import HashedNumpyArray\n\n        return hash(\n            (\n                HashedNumpyArray.hash_of_array(self.mu),\n                HashedNumpyArray.hash_of_array(self.K),\n                HashedNumpyArray.hash_of_array(self.D),\n            )\n        )\n\n    def __eq__(self, other: RigidContactsParams) -> bool:\n        if not isinstance(other, RigidContactsParams):\n            return False\n\n        return hash(self) == hash(other)\n\n    @classmethod\n    def build(\n        cls: type[Self],\n        *,\n        mu: jtp.FloatLike | None = None,\n        K: jtp.FloatLike | None = None,\n        D: jtp.FloatLike | None = None,\n        **kwargs,\n    ) -> Self:\n        \"\"\"Create a `RigidContactParams` instance.\"\"\"\n\n        return cls(\n            mu=jnp.array(\n                mu\n                if mu is not None\n      
          else cls.__dataclass_fields__[\"mu\"].default_factory()\n            ).astype(float),\n            K=jnp.array(\n                K if K is not None else cls.__dataclass_fields__[\"K\"].default_factory()\n            ).astype(float),\n            D=jnp.array(\n                D if D is not None else cls.__dataclass_fields__[\"D\"].default_factory()\n            ).astype(float),\n        )\n\n    def valid(self) -> jtp.BoolLike:\n        \"\"\"Check if the parameters are valid.\"\"\"\n        return bool(\n            jnp.all(self.mu >= 0.0)\n            and jnp.all(self.K >= 0.0)\n            and jnp.all(self.D >= 0.0)\n        )\n\n\n@jax_dataclasses.pytree_dataclass\nclass RigidContacts(ContactModel):\n    \"\"\"Rigid contacts model.\"\"\"\n\n    regularization_delassus: jax_dataclasses.Static[float] = dataclasses.field(\n        default=1e-6, kw_only=True\n    )\n\n    _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(\n        default=(\"solver_tol\",), kw_only=True\n    )\n    _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(\n        default=(1e-3,), kw_only=True\n    )\n\n    @property\n    def solver_options(self) -> dict[str, Any]:\n        \"\"\"Get the solver options as a dictionary.\"\"\"\n\n        return dict(\n            zip(\n                self._solver_options_keys,\n                self._solver_options_values,\n                strict=True,\n            )\n        )\n\n    @classmethod\n    def build(\n        cls: type[Self],\n        regularization_delassus: jtp.FloatLike | None = None,\n        solver_options: dict[str, Any] | None = None,\n        **kwargs,\n    ) -> Self:\n        \"\"\"\n        Create a `RigidContacts` instance with specified parameters.\n\n        Args:\n            regularization_delassus:\n                The regularization term to add to the diagonal of the Delassus matrix.\n            solver_options: The options to pass to the QP solver.\n    
        **kwargs: Extra arguments which are ignored.\n\n        Returns:\n            The `RigidContacts` instance.\n        \"\"\"\n\n        if kwargs:\n            logging.warning(msg=f\"Ignoring extra arguments: {kwargs}\")\n\n        # Get the default solver options.\n        default_solver_options = dict(\n            zip(cls._solver_options_keys, cls._solver_options_values, strict=True)\n        )\n\n        # Create the solver options to set by combining the default solver options\n        # with the user-provided solver options.\n        solver_options = default_solver_options | (\n            solver_options if solver_options is not None else {}\n        )\n\n        # Make sure that the solver options are hashable.\n        # We need to check this because the solver options are static.\n        try:\n            hash(tuple(solver_options.values()))\n        except TypeError as exc:\n            raise ValueError(\n                \"The values of the solver options must be hashable.\"\n            ) from exc\n\n        return cls(\n            regularization_delassus=float(\n                regularization_delassus\n                if regularization_delassus is not None\n                else cls.__dataclass_fields__[\"regularization_delassus\"].default\n            ),\n            _solver_options_keys=tuple(solver_options.keys()),\n            _solver_options_values=tuple(solver_options.values()),\n            **kwargs,\n        )\n\n    @staticmethod\n    def compute_impact_velocity(\n        inactive_collidable_points: jtp.ArrayLike,\n        M: jtp.MatrixLike,\n        J_WC: jtp.MatrixLike,\n        generalized_velocity: jtp.VectorLike,\n    ) -> jtp.Vector:\n        \"\"\"\n        Return the new velocity of the system after a potential impact.\n\n        Args:\n            inactive_collidable_points: The activation state of the collidable points.\n            M: The mass matrix of the system (in mixed representation).\n            J_WC: The Jacobian 
matrix of the collidable points (in mixed representation).\n            generalized_velocity: The generalized velocity of the system.\n\n        Note:\n            The mass matrix `M`, the Jacobian `J_WC`, and the generalized velocity `generalized_velocity`\n            must be expressed in the same velocity representation.\n        \"\"\"\n\n        # Compute system velocity after impact maintaining zero linear velocity of active points.\n        sl = jnp.s_[:, 0:3, :]\n        Jl_WC = J_WC[sl]\n\n        # Zero out the jacobian rows of inactive points.\n        Jl_WC = jnp.vstack(\n            jnp.where(\n                inactive_collidable_points[:, jnp.newaxis, jnp.newaxis],\n                jnp.zeros_like(Jl_WC),\n                Jl_WC,\n            )\n        )\n\n        A = jnp.vstack(\n            [\n                jnp.hstack([M, -Jl_WC.T]),\n                jnp.hstack([Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]),\n            ]\n        )\n        b = jnp.hstack([M @ generalized_velocity, jnp.zeros(Jl_WC.shape[0])])\n\n        BW_ν_post_impact = jnp.linalg.lstsq(A, b)[0]\n\n        return BW_ν_post_impact[0 : M.shape[0]]\n\n    @jax.jit\n    @js.common.named_scope\n    def compute_contact_forces(\n        self,\n        model: js.model.JaxSimModel,\n        data: js.data.JaxSimModelData,\n        *,\n        link_forces: jtp.MatrixLike | None = None,\n        joint_force_references: jtp.VectorLike | None = None,\n    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:\n        \"\"\"\n        Compute the contact forces.\n\n        Args:\n            model: The model to consider.\n            data: The data of the considered model.\n            link_forces:\n                Optional `(n_links, 6)` matrix of external forces acting on the links,\n                expressed in the same representation of data.\n            joint_force_references:\n                Optional `(n_joints,)` vector of joint forces.\n\n        Returns:\n            A tuple 
containing as first element the computed contact forces.\n        \"\"\"\n\n        # Get the indices of the enabled collidable points.\n        indices_of_enabled_collidable_points = (\n            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points\n        )\n\n        n_collidable_points = len(indices_of_enabled_collidable_points)\n\n        link_forces = jnp.atleast_2d(\n            jnp.array(link_forces, dtype=float).squeeze()\n            if link_forces is not None\n            else jnp.zeros((model.number_of_links(), 6))\n        )\n\n        joint_force_references = jnp.atleast_1d(\n            jnp.array(joint_force_references, dtype=float).squeeze()\n            if joint_force_references is not None\n            else jnp.zeros((model.number_of_joints(),))\n        )\n\n        # Build a references object to simplify converting link forces.\n        references = js.references.JaxSimModelReferences.build(\n            model=model,\n            data=data,\n            velocity_representation=data.velocity_representation,\n            link_forces=link_forces,\n            joint_force_references=joint_force_references,\n        )\n\n        # Compute the position and linear velocities (mixed representation) of\n        # all enabled collidable points belonging to the robot.\n        position, velocity = js.contact.collidable_point_kinematics(\n            model=model, data=data\n        )\n\n        # Compute the penetration depth and velocity of the collidable points.\n        # Note that this function considers the penetration in the normal direction.\n        δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(\n            position, velocity, model.terrain\n        )\n\n        W_H_C = js.contact.transforms(model=model, data=data)\n\n        with (\n            references.switch_velocity_representation(VelRepr.Mixed),\n            data.switch_velocity_representation(VelRepr.Mixed),\n        ):\n       
     # Compute kin-dyn quantities used in the contact model.\n            BW_ν = data.generalized_velocity\n\n            M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data)\n\n            J_WC = js.contact.jacobian(model=model, data=data)\n            J̇_WC = js.contact.jacobian_derivative(model=model, data=data)\n\n            # Compute the generalized free acceleration.\n            BW_ν̇_free = jnp.hstack(\n                js.model.forward_dynamics_aba(\n                    model=model,\n                    data=data,\n                    link_forces=references.link_forces(model=model, data=data),\n                    joint_forces=references.joint_force_references(model=model),\n                )\n            )\n\n        # Compute the free linear acceleration of the collidable points.\n        # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C.\n        free_contact_acc = _linear_acceleration_of_collidable_points(\n            BW_nu=BW_ν,\n            BW_nu_dot=BW_ν̇_free,\n            CW_J_WC_BW=J_WC,\n            CW_J_dot_WC_BW=J̇_WC,\n        ).flatten()\n\n        # Compute stabilization term.\n        baumgarte_term = _compute_baumgarte_stabilization_term(\n            inactive_collidable_points=(δ <= 0),\n            δ=δ,\n            δ_dot=δ_dot,\n            n=n̂,\n            K=model.contact_params.K,\n            D=model.contact_params.D,\n        ).flatten()\n\n        # Compute the Delassus matrix.\n        delassus_matrix = _delassus_matrix(M_inv=M_inv, J_WC=J_WC)\n\n        # Initialize regularization term of the Delassus matrix for\n        # better numerical conditioning.\n        Iε = self.regularization_delassus * jnp.eye(delassus_matrix.shape[0])\n\n        # Construct the quadratic cost function.\n        Q = delassus_matrix + Iε\n        q = free_contact_acc - baumgarte_term\n\n        # Construct the inequality constraints.\n        G = _compute_ineq_constraint_matrix(\n            
inactive_collidable_points=(δ <= 0), mu=model.contact_params.mu\n        )\n        h_bounds = jnp.zeros(shape=(n_collidable_points * 6,))\n\n        # Construct the equality constraints.\n        A = jnp.zeros((0, 3 * n_collidable_points))\n        b = jnp.zeros((0,))\n\n        # Solve the following optimization problem with qpax:\n        #\n        # min_{x} 0.5 x⊤ Q x + q⊤ x\n        #\n        # s.t. A x = b\n        #      G x ≤ h\n        #\n        # TODO: add possibility to notify if the QP problem did not converge.\n        solution, _, _, _, converged, _ = qpax.solve_qp(  # noqa: RUF059\n            Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options\n        )\n\n        # Reshape the optimized solution to be a matrix of 3D contact forces.\n        CW_fl_C = solution.reshape(-1, 3)\n\n        # Convert the contact forces from mixed to inertial-fixed representation.\n        W_f_C = jax.vmap(\n            lambda CW_fl_C, W_H_C: (\n                ModelDataWithVelocityRepresentation.other_representation_to_inertial(\n                    array=jnp.zeros(6).at[0:3].set(CW_fl_C),\n                    transform=W_H_C,\n                    other_representation=VelRepr.Mixed,\n                    is_force=True,\n                )\n            ),\n        )(CW_fl_C, W_H_C)\n\n        return W_f_C, {}\n\n    @jax.jit\n    @js.common.named_scope\n    def update_velocity_after_impact(\n        self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n    ) -> js.data.JaxSimModelData:\n        \"\"\"\n        Update the velocity after an impact.\n\n        Args:\n            model: The robot model considered by the contact model.\n            data: The data of the considered model.\n\n        Returns:\n            The updated data of the considered model.\n        \"\"\"\n\n        # Extract the indices corresponding to the enabled collidable points.\n        indices_of_enabled_collidable_points = (\n            
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points\n        )\n\n        W_p_C = js.contact.collidable_point_positions(model, data)[\n            indices_of_enabled_collidable_points\n        ]\n\n        # Compute the penetration depth of the collidable points.\n        δ, *_ = jax.vmap(\n            common.compute_penetration_data,\n            in_axes=(0, 0, None),\n        )(W_p_C, jnp.zeros_like(W_p_C), model.terrain)\n\n        with data.switch_velocity_representation(VelRepr.Mixed):\n            J_WC = js.contact.jacobian(model, data)[\n                indices_of_enabled_collidable_points\n            ]\n            M = js.model.free_floating_mass_matrix(model, data)\n            BW_ν_pre_impact = data.generalized_velocity\n\n            # Compute the impact velocity.\n            # It may be discontinuous in case new contacts are made.\n            BW_ν_post_impact = RigidContacts.compute_impact_velocity(\n                generalized_velocity=BW_ν_pre_impact,\n                inactive_collidable_points=(δ <= 0),\n                M=M,\n                J_WC=J_WC,\n            )\n\n            BW_ν_post_impact_inertial = data.other_representation_to_inertial(\n                array=BW_ν_post_impact[0:6],\n                other_representation=VelRepr.Mixed,\n                transform=data._base_transform.at[0:3, 0:3].set(jnp.eye(3)),\n                is_force=False,\n            )\n\n        # Reset the generalized velocity.\n        data = dataclasses.replace(\n            data,\n            _base_linear_velocity=BW_ν_post_impact_inertial[0:3],\n            _base_angular_velocity=BW_ν_post_impact_inertial[3:6],\n            _joint_velocities=BW_ν_post_impact[6:],\n        )\n\n        return data\n\n    def update_contact_state(\n        self: type[Self], old_contact_state: dict[str, jtp.Array]\n    ) -> dict[str, jtp.Array]:\n        \"\"\"\n        Update the contact state.\n\n        Args:\n            old_contact_state: The 
old contact state.\n\n        Returns:\n            The updated contact state.\n        \"\"\"\n\n        return {}\n\n\n@staticmethod\ndef _delassus_matrix(\n    M_inv: jtp.MatrixLike,\n    J_WC: jtp.MatrixLike,\n) -> jtp.Matrix:\n\n    sl = jnp.s_[:, 0:3, :]\n    J_WC_lin = jnp.vstack(J_WC[sl])\n\n    delassus_matrix = J_WC_lin @ M_inv @ J_WC_lin.T\n    return delassus_matrix\n\n\n@jax.jit\n@js.common.named_scope\ndef _compute_ineq_constraint_matrix(\n    inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the inequality constraint matrix for a single collidable point.\n\n    Rows 0-3: enforce the friction pyramid constraint,\n    Row 4: last one is for the non negativity of the vertical force\n    Row 5: contact complementarity condition\n    \"\"\"\n    G_single_point = jnp.array(\n        [\n            [1, 0, -mu],\n            [0, 1, -mu],\n            [-1, 0, -mu],\n            [0, -1, -mu],\n            [0, 0, -1],\n            [0, 0, 0],\n        ]\n    )\n    G = jnp.tile(G_single_point, (len(inactive_collidable_points), 1, 1))\n    G = G.at[:, 5, 2].set(inactive_collidable_points)\n\n    G = jax.scipy.linalg.block_diag(*G)\n    return G\n\n\n@jax.jit\n@js.common.named_scope\ndef _linear_acceleration_of_collidable_points(\n    BW_nu: jtp.ArrayLike,\n    BW_nu_dot: jtp.ArrayLike,\n    CW_J_WC_BW: jtp.MatrixLike,\n    CW_J_dot_WC_BW: jtp.MatrixLike,\n) -> jtp.Matrix:\n\n    BW_ν = BW_nu\n    BW_ν̇ = BW_nu_dot\n    CW_J̇_WC_BW = CW_J_dot_WC_BW\n\n    # Compute the linear acceleration of the collidable points.\n    # Since we use doubly-mixed jacobians, this corresponds to W_p̈_C.\n    CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇\n\n    CW_a_WC = CW_a_WC.reshape(-1, 6)\n    return CW_a_WC[:, 0:3].squeeze()\n\n\n@jax.jit\n@js.common.named_scope\ndef _compute_baumgarte_stabilization_term(\n    inactive_collidable_points: jtp.ArrayLike,\n    δ: jtp.ArrayLike,\n    δ_dot: 
jtp.ArrayLike,\n    n: jtp.ArrayLike,\n    K: jtp.FloatLike,\n    D: jtp.FloatLike,\n) -> jtp.Array:\n\n    return jnp.where(\n        inactive_collidable_points[:, jnp.newaxis],\n        jnp.zeros_like(n),\n        (K * δ + D * δ_dot)[:, jnp.newaxis] * n,\n    )\n"
  },
  {
    "path": "src/jaxsim/rbda/contacts/soft.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nimport functools\n\nimport jax\nimport jax.numpy as jnp\nimport jax_dataclasses\n\nimport jaxsim.api as js\nimport jaxsim.math\nimport jaxsim.typing as jtp\nfrom jaxsim import logging\nfrom jaxsim.terrain import Terrain\n\nfrom . import common\n\ntry:\n    from typing import Self\nexcept ImportError:\n    from typing_extensions import Self\n\n\n@jax_dataclasses.pytree_dataclass\nclass SoftContactsParams(common.ContactsParams):\n    \"\"\"Parameters of the soft contacts model.\"\"\"\n\n    K: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(1e6, dtype=float)\n    )\n\n    D: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(2000, dtype=float)\n    )\n\n    mu: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(0.5, dtype=float)\n    )\n\n    p: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(0.5, dtype=float)\n    )\n\n    q: jtp.Float = dataclasses.field(\n        default_factory=lambda: jnp.array(0.5, dtype=float)\n    )\n\n    def __hash__(self) -> int:\n\n        from jaxsim.utils.wrappers import HashedNumpyArray\n\n        return hash(\n            (\n                HashedNumpyArray.hash_of_array(self.K),\n                HashedNumpyArray.hash_of_array(self.D),\n                HashedNumpyArray.hash_of_array(self.mu),\n                HashedNumpyArray.hash_of_array(self.p),\n                HashedNumpyArray.hash_of_array(self.q),\n            )\n        )\n\n    def __eq__(self, other: SoftContactsParams) -> bool:\n\n        if not isinstance(other, SoftContactsParams):\n            return False\n\n        return hash(self) == hash(other)\n\n    @classmethod\n    def build(\n        cls: type[Self],\n        *,\n        K: jtp.FloatLike = 1e6,\n        D: jtp.FloatLike = 2_000,\n        mu: jtp.FloatLike = 0.5,\n        p: jtp.FloatLike = 0.5,\n        q: jtp.FloatLike = 0.5,\n        
**kwargs,\n    ) -> Self:\n        \"\"\"\n        Create a SoftContactsParams instance with specified parameters.\n\n        Args:\n            K: The stiffness parameter.\n            D: The damping parameter of the soft contacts model.\n            mu: The static friction coefficient.\n            p:\n                The exponent p corresponding to the damping-related non-linearity\n                of the Hunt/Crossley model.\n            q:\n                The exponent q corresponding to the spring-related non-linearity\n                of the Hunt/Crossley model\n            **kwargs: Additional parameters to pass to the contact model.\n\n        Returns:\n            A SoftContactsParams instance with the specified parameters.\n        \"\"\"\n\n        return SoftContactsParams(\n            K=jnp.array(K, dtype=float),\n            D=jnp.array(D, dtype=float),\n            mu=jnp.array(mu, dtype=float),\n            p=jnp.array(p, dtype=float),\n            q=jnp.array(q, dtype=float),\n        )\n\n    def valid(self) -> jtp.BoolLike:\n        \"\"\"\n        Check if the parameters are valid.\n\n        Returns:\n            `True` if the parameters are valid, `False` otherwise.\n        \"\"\"\n\n        return jnp.hstack(\n            [\n                self.K >= 0.0,\n                self.D >= 0.0,\n                self.mu >= 0.0,\n                self.p >= 0.0,\n                self.q >= 0.0,\n            ]\n        ).all()\n\n\n@jax_dataclasses.pytree_dataclass\nclass SoftContacts(common.ContactModel):\n    \"\"\"Soft contacts model.\"\"\"\n\n    @classmethod\n    def build(\n        cls: type[Self],\n        **kwargs,\n    ) -> Self:\n        \"\"\"\n        Create a `SoftContacts` instance with specified parameters.\n\n        Args:\n            **kwargs: Additional parameters to pass to the contact model.\n\n        Returns:\n            The `SoftContacts` instance.\n        \"\"\"\n\n        if kwargs:\n            
logging.warning(msg=f\"Ignoring extra arguments: {kwargs}\")\n\n        return cls(**kwargs)\n\n    @classmethod\n    def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]:\n        \"\"\"\n        Build zero state variables of the contact model.\n        \"\"\"\n\n        # Initialize the material deformation to zero.\n        tangential_deformation = jnp.zeros(\n            shape=(len(model.kin_dyn_parameters.contact_parameters.body), 3),\n            dtype=float,\n        )\n\n        return {\"tangential_deformation\": tangential_deformation}\n\n    def update_contact_state(\n        self: type[Self], old_contact_state: dict[str, jtp.Array]\n    ) -> dict[str, jtp.Array]:\n        \"\"\"\n        Update the contact state.\n\n        Args:\n            old_contact_state: The old contact state.\n\n        Returns:\n            The updated contact state.\n        \"\"\"\n\n        return {\"tangential_deformation\": old_contact_state[\"m_dot\"]}\n\n    def update_velocity_after_impact(\n        self: type[Self], model: js.model.JaxSimModel, data: js.data.JaxSimModelData\n    ) -> js.data.JaxSimModelData:\n        \"\"\"\n        Update the velocity after an impact.\n\n        Args:\n            model: The robot model considered by the contact model.\n            data: The data of the considered model.\n\n        Returns:\n            The updated data of the considered model.\n        \"\"\"\n\n        return data\n\n    @staticmethod\n    @functools.partial(jax.jit, static_argnames=(\"terrain\",))\n    def hunt_crossley_contact_model(\n        position: jtp.VectorLike,\n        velocity: jtp.VectorLike,\n        tangential_deformation: jtp.VectorLike,\n        terrain: Terrain,\n        K: jtp.FloatLike,\n        D: jtp.FloatLike,\n        mu: jtp.FloatLike,\n        p: jtp.FloatLike = 0.5,\n        q: jtp.FloatLike = 0.5,\n    ) -> tuple[jtp.Vector, jtp.Vector]:\n        \"\"\"\n        Compute the contact force using the 
Hunt/Crossley model.\n\n        Args:\n            position: The position of the collidable point.\n            velocity: The velocity of the collidable point.\n            tangential_deformation: The material deformation of the collidable point.\n            terrain: The terrain model.\n            K: The stiffness parameter.\n            D: The damping parameter of the soft contacts model.\n            mu: The static friction coefficient.\n            p:\n                The exponent p corresponding to the damping-related non-linearity\n                of the Hunt/Crossley model.\n            q:\n                The exponent q corresponding to the spring-related non-linearity\n                of the Hunt/Crossley model\n\n        Returns:\n            A tuple containing the computed contact force and the derivative of the\n            material deformation.\n        \"\"\"\n\n        # Convert the input vectors to arrays.\n        W_p_C = jnp.array(position, dtype=float).squeeze()\n        W_ṗ_C = jnp.array(velocity, dtype=float).squeeze()\n        m = jnp.array(tangential_deformation, dtype=float).squeeze()\n\n        # Use symbol for the static friction.\n        μ = mu\n\n        # Compute the penetration depth, its rate, and the considered terrain normal.\n        δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain)\n\n        # There are few operations like computing the norm of a vector with zero length\n        # or computing the square root of zero that are problematic in an AD context.\n        # To avoid these issues, we introduce a small tolerance ε to their arguments\n        # and make sure that we do not check them against zero directly.\n        ε = jnp.finfo(float).eps\n\n        # Compute the powers of the penetration depth.\n        # Inject ε to address AD issues in differentiating the square root when\n        #  p and q are fractional.\n        δp = jnp.power(δ + ε, p)\n        δq = jnp.power(δ + ε, q)\n\n        # 
========================\n        # Compute the normal force\n        # ========================\n\n        # Non-linear spring-damper model (Hunt/Crossley model).\n        # This is the force magnitude along the direction normal to the terrain.\n        force_normal_mag = (K * δp) * δ + (D * δq) * δ̇\n\n        # Depending on the magnitude of δ̇, the normal force could be negative.\n        force_normal_mag = jnp.maximum(0.0, force_normal_mag)\n\n        # Compute the 3D linear force in C[W] frame.\n        f_normal = force_normal_mag * n̂\n\n        # ============================\n        # Compute the tangential force\n        # ============================\n\n        # Extract the tangential component of the velocity.\n        v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, n̂) * n̂\n\n        # Extract the normal and tangential components of the material deformation.\n        m_normal = jnp.dot(m, n̂) * n̂\n        m_tangential = m - jnp.dot(m, n̂) * n̂\n\n        # Compute the tangential force in the sticking case.\n        # Using the tangential component of the material deformation should not be\n        # necessary if the sticking-slipping transition occurs in a terrain area\n        # with a locally constant normal. 
However, this assumption is not true in\n        # general, especially for highly uneven terrains.\n        f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential)\n\n        # Detect the contact type (sticking or slipping).\n        # Note that if there is no contact, sticking is set to True, and this detail\n        # is exploited in the computation of the `contact_status` variable.\n        sticking = jnp.logical_or(\n            δ <= 0, f_tangential.dot(f_tangential) <= (μ * force_normal_mag) ** 2\n        )\n\n        # Compute the direction of the tangential force.\n        # To prevent dividing by zero, we use a switch statement.\n        norm = jaxsim.math.safe_norm(f_tangential)\n        f_tangential_direction = f_tangential / (\n            norm + jnp.finfo(float).eps * (norm == 0)\n        )\n\n        # Project the tangential force to the friction cone if slipping.\n        f_tangential = jnp.where(\n            sticking,\n            f_tangential,\n            jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction,\n        )\n\n        # Set the tangential force to zero if there is no contact.\n        f_tangential = jnp.where(δ <= 0, jnp.zeros(3), f_tangential)\n\n        # =====================================\n        # Compute the material deformation rate\n        # =====================================\n\n        # Compute the derivative of the material deformation.\n        # Note that we included an additional relaxation of `m_normal` in the\n        # sticking case, so that the normal deformation that could have accumulated\n        # from a previous slipping phase can relax to zero.\n        ṁ_no_contact = -(K / D) * m\n        ṁ_sticking = v_tangential - (K / D) * m_normal\n        ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq)\n\n        # Compute the contact status:\n        # 0: slipping\n        # 1: sticking\n        # 2: no contact\n        contact_status = sticking.astype(int)\n        
contact_status += (δ <= 0).astype(int)\n\n        # Select the right material deformation rate depending on the contact status.\n        ṁ = jax.lax.select_n(contact_status, ṁ_slipping, ṁ_sticking, ṁ_no_contact)\n\n        # ==========================================\n        # Compute and return the final contact force\n        # ==========================================\n\n        # Sum the normal and tangential forces.\n        CW_fl = f_normal + f_tangential\n\n        return CW_fl, ṁ\n\n    @staticmethod\n    @functools.partial(jax.jit, static_argnames=(\"terrain\",))\n    def compute_contact_force(\n        position: jtp.VectorLike,\n        velocity: jtp.VectorLike,\n        tangential_deformation: jtp.VectorLike,\n        parameters: SoftContactsParams,\n        terrain: Terrain,\n    ) -> tuple[jtp.Vector, jtp.Vector]:\n        \"\"\"\n        Compute the contact force.\n\n        Args:\n            position: The position of the collidable point.\n            velocity: The velocity of the collidable point.\n            tangential_deformation: The material deformation of the collidable point.\n            parameters: The parameters of the soft contacts model.\n            terrain: The terrain model.\n\n        Returns:\n            A tuple containing the computed contact force and the derivative of the\n            material deformation.\n        \"\"\"\n\n        CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model(\n            position=position,\n            velocity=velocity,\n            tangential_deformation=tangential_deformation,\n            terrain=terrain,\n            K=parameters.K,\n            D=parameters.D,\n            mu=parameters.mu,\n            p=parameters.p,\n            q=parameters.q,\n        )\n\n        # Pack a mixed 6D force.\n        CW_f = jnp.hstack([CW_fl, jnp.zeros(3)])\n\n        # Compute the 6D force transform from the mixed to the inertial-fixed frame.\n        W_Xf_CW = 
jaxsim.math.Adjoint.from_quaternion_and_translation(\n            translation=jnp.array(position), inverse=True\n        ).T\n\n        # Compute the 6D force in the inertial-fixed frame.\n        W_f = W_Xf_CW @ CW_f\n\n        return W_f, ṁ\n\n    @staticmethod\n    @jax.jit\n    def compute_contact_forces(\n        model: js.model.JaxSimModel,\n        data: js.data.JaxSimModelData,\n    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:\n        \"\"\"\n        Compute the contact forces.\n\n        Args:\n            model: The model to consider.\n            data: The data of the considered model.\n\n        Returns:\n            A tuple containing as first element the computed contact forces, and as\n            second element a dictionary with derivative of the material deformation.\n        \"\"\"\n\n        # Get the indices of the enabled collidable points.\n        indices_of_enabled_collidable_points = (\n            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points\n        )\n\n        # Compute the position and linear velocities (mixed representation) of\n        # all the collidable points belonging to the robot and extract the ones\n        # for the enabled collidable points.\n        W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data)\n\n        # Extract the material deformation corresponding to the collidable points.\n        m = (\n            data.contact_state[\"tangential_deformation\"]\n            if \"tangential_deformation\" in data.contact_state\n            else jnp.zeros_like(W_p_C)\n        )\n\n        m_enabled = m[indices_of_enabled_collidable_points]\n\n        # Initialize the tangential deformation rate array for every collidable point.\n        ṁ = jnp.zeros_like(m)\n\n        # Compute the contact forces only for the enabled collidable points.\n        # Since we treat them as independent, we can vmap the computation.\n        W_f, ṁ_enabled = jax.vmap(\n            
lambda p, v, m: SoftContacts.compute_contact_force(\n                position=p,\n                velocity=v,\n                tangential_deformation=m,\n                parameters=model.contact_params,\n                terrain=model.terrain,\n            )\n        )(W_p_C, W_ṗ_C, m_enabled)\n\n        ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled)\n\n        return W_f, {\"m_dot\": ṁ}\n"
  },
  {
    "path": "src/jaxsim/rbda/crba.py",
    "content": "import jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\n\nfrom . import utils\n\n\ndef crba(\n    model: js.model.JaxSimModel,\n    *,\n    joint_positions: jtp.Vector,\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the free-floating mass matrix using the Composite Rigid-Body Algorithm (CRBA).\n\n    Args:\n        model: The model to consider.\n        joint_positions: The positions of the joints.\n\n    Returns:\n        The free-floating mass matrix of the model in body-fixed representation.\n    \"\"\"\n\n    _, _, s, _, _, _, _, _, _, _ = utils.process_inputs(\n        model=model, joint_positions=joint_positions\n    )\n\n    # Get the 6D spatial inertia matrices of all links.\n    Mc = js.model.link_spatial_inertia_matrices(model=model)\n\n    # Get the parent array λ(i).\n    # Note: λ(0) must not be used, it's initialized to -1.\n    λ = model.kin_dyn_parameters.parent_array\n\n    # Compute the parent-to-child adjoints of the joints.\n    # These transforms define the relative kinematics of the entire model, including\n    # the base transform for both floating-base and fixed-base models.\n    i_X_λi = model.kin_dyn_parameters.joint_transforms(\n        joint_positions=s, base_transform=jnp.eye(4)\n    )\n\n    # Extract the joint motion subspaces.\n    S = model.kin_dyn_parameters.motion_subspaces\n\n    # Allocate the buffer of transforms link -> base.\n    i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))\n    i_X_0 = i_X_0.at[0].set(jnp.eye(6))\n\n    # ====================\n    # Propagate kinematics\n    # ====================\n\n    ForwardPassCarry = tuple[jtp.Matrix]\n    forward_pass_carry: ForwardPassCarry = (i_X_0,)\n\n    def propagate_kinematics(\n        carry: ForwardPassCarry, i: jtp.Int\n    ) -> tuple[ForwardPassCarry, None]:\n\n        (i_X_0,) = carry\n\n        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]\n        i_X_0 = i_X_0.at[i].set(i_X_0_i)\n\n        return (i_X_0,), None\n\n    
(i_X_0,), _ = (\n        jax.lax.scan(\n            f=propagate_kinematics,\n            init=forward_pass_carry,\n            xs=jnp.arange(start=1, stop=model.number_of_links()),\n        )\n        if model.number_of_links() > 1\n        else [(i_X_0,), None]\n    )\n\n    # ===================\n    # Compute mass matrix\n    # ===================\n\n    M = jnp.zeros(shape=(6 + model.dofs(), 6 + model.dofs()))\n\n    BackwardPassCarry = tuple[jtp.Matrix, jtp.Matrix]\n    backward_pass_carry: BackwardPassCarry = (Mc, M)\n\n    def backward_pass(\n        carry: BackwardPassCarry, i: jtp.Int\n    ) -> tuple[BackwardPassCarry, None]:\n\n        ii = i - 1\n        Mc, M = carry\n\n        Mc_λi = Mc[λ[i]] + i_X_λi[i].T @ Mc[i] @ i_X_λi[i]\n        Mc = Mc.at[λ[i]].set(Mc_λi)\n\n        Fi = Mc[i] @ S[i]\n        M_ii = S[i].T @ Fi\n        M = M.at[ii + 6, ii + 6].set(M_ii.squeeze())\n\n        j = i\n\n        FakeWhileCarry = tuple[jtp.Int, jtp.Vector, jtp.Matrix]\n        fake_while_carry = (j, Fi, M)\n\n        # This internal for loop implements the while loop of the CRBA algorithm\n        # to compute off-diagonal blocks of the mass matrix M.\n        # In pseudocode it is implemented as a while loop. 
However, in order to enable\n        # applying reverse-mode AD, we implement it as a nested for loop with a fixed\n        # number of iterations and a branching model to skip for loop iterations.\n        def fake_while_loop(\n            carry: FakeWhileCarry, i: jtp.Int\n        ) -> tuple[FakeWhileCarry, None]:\n\n            def compute(carry: FakeWhileCarry) -> FakeWhileCarry:\n\n                j, Fi, M = carry\n\n                Fi = i_X_λi[j].T @ Fi\n                j = λ[j]\n\n                M_ij = Fi.T @ S[j]\n\n                jj = j - 1\n                M = M.at[ii + 6, jj + 6].set(M_ij.squeeze())\n                M = M.at[jj + 6, ii + 6].set(M_ij.squeeze())\n\n                return j, Fi, M\n\n            j, _, _ = carry\n\n            j, Fi, M = jax.lax.cond(\n                pred=jnp.logical_and(i == λ[j], λ[j] > 0),\n                true_fun=compute,\n                false_fun=lambda carry: carry,\n                operand=carry,\n            )\n\n            return (j, Fi, M), None\n\n        (j, Fi, M), _ = (\n            jax.lax.scan(\n                f=fake_while_loop,\n                init=fake_while_carry,\n                xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),\n            )\n            if model.number_of_links() > 1\n            else [(j, Fi, M), None]\n        )\n\n        Fi = i_X_0[j].T @ Fi\n\n        M = M.at[0:6, ii + 6].set(Fi.squeeze())\n        M = M.at[ii + 6, 0:6].set(Fi.squeeze())\n\n        return (Mc, M), None\n\n    # This scan performs the backward pass to compute Mbj, Mjb and Mjj, that\n    # also includes a fake while loop implemented with a scan and two cond.\n    (Mc, M), _ = (\n        jax.lax.scan(\n            f=backward_pass,\n            init=backward_pass_carry,\n            xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),\n        )\n        if model.number_of_links() > 1\n        else [(Mc, M), None]\n    )\n\n    # Store the locked 6D rigid-body inertia matrix Mbb ∈ 
ℝ⁶ˣ⁶.\n    M = M.at[0:6, 0:6].set(Mc[0])\n\n    return M\n"
  },
  {
    "path": "src/jaxsim/rbda/forward_kinematics.py",
    "content": "import jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.math import Adjoint\n\nfrom . import utils\n\n\ndef forward_kinematics_model(\n    model: js.model.JaxSimModel,\n    *,\n    base_position: jtp.VectorLike,\n    base_quaternion: jtp.VectorLike,\n    joint_positions: jtp.VectorLike,\n    base_linear_velocity_inertial: jtp.VectorLike,\n    base_angular_velocity_inertial: jtp.VectorLike,\n    joint_velocities: jtp.VectorLike,\n    joint_transforms: jtp.MatrixLike,\n) -> tuple[jtp.Array, jtp.Array]:\n    \"\"\"\n    Compute the forward kinematics.\n\n    Args:\n        model: The model to consider.\n        base_position: The position of the base link.\n        base_quaternion: The quaternion of the base link.\n        joint_positions: The positions of the joints.\n        base_linear_velocity_inertial: The linear velocity of the base link in inertial-fixed representation.\n        base_angular_velocity_inertial: The angular velocity of the base link in inertial-fixed representation.\n        joint_velocities: The velocities of the joints.\n        joint_transforms: The parent-to-child transforms of the joints.\n\n    Returns:\n        A tuple containing the SE(3) transforms and the 6D velocities of all links.\n    \"\"\"\n\n    _, _, _, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs(\n        model=model,\n        base_position=base_position,\n        base_quaternion=base_quaternion,\n        joint_positions=joint_positions,\n        base_linear_velocity=base_linear_velocity_inertial,\n        base_angular_velocity=base_angular_velocity_inertial,\n        joint_velocities=joint_velocities,\n    )\n\n    # Get the parent array λ(i).\n    # Note: λ(0) must not be used, it's initialized to -1.\n    λ = model.kin_dyn_parameters.parent_array\n\n    # Extract the parent-to-child adjoints of the joints.\n    i_X_λi = jnp.asarray(joint_transforms)\n\n    # Allocate the buffer of transforms world -> link 
and initialize the base pose.\n    W_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))\n    W_X_i = W_X_i.at[0].set(Adjoint.inverse(i_X_λi[0]))\n\n    # Allocate buffer of 6D inertial-fixed velocities and initialize the base velocity.\n    W_v_Wi = jnp.zeros(shape=(model.number_of_links(), 6))\n    W_v_Wi = W_v_Wi.at[0].set(W_v_WB)\n\n    # Extract the joint motion subspaces.\n    S = model.kin_dyn_parameters.motion_subspaces\n\n    # ========================\n    # Propagate the kinematics\n    # ========================\n\n    PropagateKinematicsCarry = tuple[jtp.Matrix, jtp.Matrix]\n    propagate_kinematics_carry: PropagateKinematicsCarry = (W_X_i, W_v_Wi)\n\n    def propagate_kinematics(\n        carry: PropagateKinematicsCarry, i: jtp.Int\n    ) -> tuple[PropagateKinematicsCarry, None]:\n\n        ii = i - 1\n        W_X_i, W_v_Wi = carry\n\n        # Compute the parent to child 6D transform.\n        λi_X_i = Adjoint.inverse(adjoint=i_X_λi[i])\n\n        # Compute the world to child 6D transform.\n        W_Xi_i = W_X_i[λ[i]] @ λi_X_i\n        W_X_i = W_X_i.at[i].set(W_Xi_i)\n\n        # Propagate the 6D velocity.\n        W_vi_Wi = W_v_Wi[λ[i]] + W_X_i[i] @ (S[i] * ṡ[ii]).squeeze()\n        W_v_Wi = W_v_Wi.at[i].set(W_vi_Wi)\n\n        return (W_X_i, W_v_Wi), None\n\n    (W_X_i, W_v_Wi), _ = (\n        jax.lax.scan(\n            f=propagate_kinematics,\n            init=propagate_kinematics_carry,\n            xs=jnp.arange(start=1, stop=model.number_of_links()),\n        )\n        if model.number_of_links() > 1\n        else [(W_X_i, W_v_Wi), None]\n    )\n\n    return jax.vmap(Adjoint.to_transform)(W_X_i), W_v_Wi\n"
  },
  {
    "path": "src/jaxsim/rbda/forward_kinematics_parallel.py",
    "content": "import math\n\nimport jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.math import Adjoint\n\nfrom . import utils\n\n\ndef forward_kinematics_model_parallel(\n    model: js.model.JaxSimModel,\n    *,\n    base_position: jtp.VectorLike,\n    base_quaternion: jtp.VectorLike,\n    joint_positions: jtp.VectorLike,\n    base_linear_velocity_inertial: jtp.VectorLike,\n    base_angular_velocity_inertial: jtp.VectorLike,\n    joint_velocities: jtp.VectorLike,\n    joint_transforms: jtp.MatrixLike,\n) -> tuple[jtp.Array, jtp.Array]:\n    \"\"\"\n    Compute forward kinematics using pointer jumping on the kinematic tree.\n\n    Uses an associative binary operator on transform-velocity pairs to\n    compute all world-frame transforms and velocities in O(log D) parallel\n    steps, where D is the tree depth.\n\n    The interface and semantics are identical to\n    :func:`forward_kinematics_model`, but parallelized via pointer jumping.\n    \"\"\"\n\n    _, _, _, W_v_WB, ṡ, _, _, _, _, _ = utils.process_inputs(\n        model=model,\n        base_position=base_position,\n        base_quaternion=base_quaternion,\n        joint_positions=joint_positions,\n        base_linear_velocity=base_linear_velocity_inertial,\n        base_angular_velocity=base_angular_velocity_inertial,\n        joint_velocities=joint_velocities,\n    )\n\n    # Extract the parent-to-child adjoints of the joints.\n    i_X_λi = jnp.asarray(joint_transforms)\n\n    # Extract the joint motion subspaces.\n    S = model.kin_dyn_parameters.motion_subspaces\n\n    n = model.number_of_links()\n\n    # Compute local transforms λ(i)_X_i by inverting the child-to-parent adjoints.\n    L = jax.vmap(Adjoint.inverse)(i_X_λi)  # (n, 6, 6)\n\n    # Compute local velocity contributions.\n    ṡ_padded = jnp.concatenate([jnp.zeros(1), jnp.atleast_1d(ṡ.squeeze())])  # (n,)\n    vJ = (S * ṡ_padded[:, None, None]).squeeze(-1)  # (n, 6)\n    u = 
jnp.einsum(\"nij,nj->ni\", L, vJ)  # (n, 6)\n    u = u.at[0].set(W_v_WB)\n\n    # Get the parent array λ(i) with root self-loop.\n    # Note: λ(0) is set to 0 to enable root self-referencing.\n    ptr = jnp.asarray(model.kin_dyn_parameters.parent_array).at[0].set(0)\n    done = jnp.arange(n) == 0\n\n    # Number of pointer-jumping rounds.\n    n_levels = model.kin_dyn_parameters.level_nodes.shape[0]\n    n_rounds = max(1, math.ceil(math.log2(max(n_levels, 2))))\n\n    # ===============\n    # Pointer jumping\n    # ===============\n\n    # Each round composes the node state with its current pointer target,\n    # then doubles the jump distance. After ceil(log2 D) rounds every node\n    # has accumulated the full root-to-node transform and velocity.\n\n    def _pointer_jump(carry, _):\n        L, u, ptr, done = carry\n        need = ~done\n\n        L_par = L[ptr]\n        u_par = u[ptr]\n\n        # Associative compose.\n        L_new = jnp.where(need[:, None, None], L_par @ L, L)\n        u_new = jnp.where(\n            need[:, None],\n            u_par + jnp.einsum(\"nij,nj->ni\", L_par, u),\n            u,\n        )\n\n        ptr_new = jnp.where(need, ptr[ptr], ptr)\n        done_new = done | done[ptr]\n\n        return (L_new, u_new, ptr_new, done_new), None\n\n    (W_X_i, W_v_Wi, _, _), _ = (\n        jax.lax.scan(\n            f=_pointer_jump,\n            init=(L, u, ptr, done),\n            xs=jnp.arange(n_rounds),\n        )\n        if n > 1\n        else ((L, u, ptr, done), None)\n    )\n\n    return jax.vmap(Adjoint.to_transform)(W_X_i), W_v_Wi\n"
  },
  {
    "path": "src/jaxsim/rbda/jacobian.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport numpy as np\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.math import Adjoint, Cross\n\nfrom . import utils\n\n\ndef jacobian(\n    model: js.model.JaxSimModel,\n    *,\n    link_index: jtp.Int,\n    joint_positions: jtp.VectorLike,\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the free-floating Jacobian of a link.\n\n    Args:\n        model: The model to consider.\n        link_index: The index of the link for which to compute the Jacobian matrix.\n        joint_positions: The positions of the joints.\n\n    Returns:\n        The free-floating left-trivialized Jacobian of the link :math:`{}^L J_{W,L/B}`.\n    \"\"\"\n\n    _, _, s, _, _, _, _, _, _, _ = utils.process_inputs(\n        model=model, joint_positions=joint_positions\n    )\n\n    # Get the parent array λ(i).\n    # Note: λ(0) must not be used, it's initialized to -1.\n    λ = model.kin_dyn_parameters.parent_array\n\n    # Compute the parent-to-child adjoints of the joints.\n    # These transforms define the relative kinematics of the entire model, including\n    # the base transform for both floating-base and fixed-base models.\n    i_X_λi = model.kin_dyn_parameters.joint_transforms(\n        joint_positions=s, base_transform=jnp.eye(4)\n    )\n\n    # Extract the joint motion subspaces.\n    S = model.kin_dyn_parameters.motion_subspaces\n\n    # Allocate the buffer of transforms link -> base.\n    i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))\n    i_X_0 = i_X_0.at[0].set(jnp.eye(6))\n\n    # ====================\n    # Propagate kinematics\n    # ====================\n\n    PropagateKinematicsCarry = tuple[jtp.Matrix]\n    propagate_kinematics_carry: PropagateKinematicsCarry = (i_X_0,)\n\n    def propagate_kinematics(\n        carry: PropagateKinematicsCarry, i: jtp.Int\n    ) -> tuple[PropagateKinematicsCarry, None]:\n\n        (i_X_0,) = carry\n\n        # Compute the base (0) to link (i) adjoint matrix.\n    
    # This works fine since we traverse the kinematic tree following the link\n        # indices assigned with BFS.\n        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]\n        i_X_0 = i_X_0.at[i].set(i_X_0_i)\n\n        return (i_X_0,), None\n\n    (i_X_0,), _ = (\n        jax.lax.scan(\n            f=propagate_kinematics,\n            init=propagate_kinematics_carry,\n            xs=np.arange(start=1, stop=model.number_of_links()),\n        )\n        if model.number_of_links() > 1\n        else [(i_X_0,), None]\n    )\n\n    # ============================\n    # Compute doubly-left Jacobian\n    # ============================\n\n    J = jnp.zeros(shape=(6, 6 + model.dofs()))\n\n    Jb = i_X_0[link_index]\n    J = J.at[0:6, 0:6].set(Jb)\n\n    # To make JIT happy, we operate on a boolean version of κ(i).\n    # Checking if j ∈ κ(i) is equivalent to: κ_bool(j) is True.\n    κ_bool = model.kin_dyn_parameters.support_body_array_bool[link_index]\n\n    def compute_jacobian(J: jtp.Matrix, i: jtp.Int) -> tuple[jtp.Matrix, None]:\n\n        def update_jacobian(J: jtp.Matrix, i: jtp.Int) -> jtp.Matrix:\n\n            ii = i - 1\n\n            Js_i = i_X_0[link_index] @ Adjoint.inverse(i_X_0[i]) @ S[i]\n            J = J.at[0:6, 6 + ii].set(Js_i.squeeze())\n\n            return J\n\n        J = jax.lax.select(\n            pred=κ_bool[i],\n            on_true=update_jacobian(J, i),\n            on_false=J,\n        )\n\n        return J, None\n\n    L_J_WL_B, _ = (\n        jax.lax.scan(\n            f=compute_jacobian,\n            init=J,\n            xs=np.arange(start=1, stop=model.number_of_links()),\n        )\n        if model.number_of_links() > 1\n        else [J, None]\n    )\n\n    return L_J_WL_B\n\n\n@jax.jit\ndef jacobian_full_doubly_left(\n    model: js.model.JaxSimModel,\n    *,\n    joint_positions: jtp.VectorLike,\n) -> tuple[jtp.Matrix, jtp.Array]:\n    r\"\"\"\n    Compute the doubly-left full free-floating Jacobian of a model.\n\n    The full Jacobian is a 
6x(6+n) matrix with all the columns filled.\n    It is useful to run the algorithm once, and then extract the link Jacobian by\n    filtering the columns of the full Jacobian using the support parent array\n    :math:`\\kappa(i)` of the link.\n\n    Args:\n        model: The model to consider.\n        joint_positions: The positions of the joints.\n\n    Returns:\n        The doubly-left full free-floating Jacobian of a model.\n    \"\"\"\n\n    _, _, s, _, _, _, _, _, _, _ = utils.process_inputs(\n        model=model, joint_positions=joint_positions\n    )\n\n    # Get the parent array λ(i).\n    # Note: λ(0) must not be used, it's initialized to -1.\n    λ = model.kin_dyn_parameters.parent_array\n\n    # Compute the parent-to-child adjoints of the joints.\n    # These transforms define the relative kinematics of the entire model, including\n    # the base transform for both floating-base and fixed-base models.\n    i_X_λi = model.kin_dyn_parameters.joint_transforms(\n        joint_positions=s, base_transform=jnp.eye(4)\n    )\n\n    # Extract the joint motion subspaces.\n    S = model.kin_dyn_parameters.motion_subspaces\n\n    # Allocate the buffer of transforms base -> link.\n    B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))\n    B_X_i = B_X_i.at[0].set(jnp.eye(6))\n\n    # =================================\n    # Compute doubly-left full Jacobian\n    # =================================\n\n    # Allocate the Jacobian matrix.\n    # The Jbb section of the doubly-left Jacobian is an identity matrix.\n    J = jnp.zeros(shape=(6, 6 + model.dofs()))\n    J = J.at[0:6, 0:6].set(jnp.eye(6))\n\n    ComputeFullJacobianCarry = tuple[jtp.Matrix, jtp.Matrix]\n    compute_full_jacobian_carry: ComputeFullJacobianCarry = (B_X_i, J)\n\n    def compute_full_jacobian(\n        carry: ComputeFullJacobianCarry, i: jtp.Int\n    ) -> tuple[ComputeFullJacobianCarry, None]:\n\n        ii = i - 1\n        B_X_i, J = carry\n\n        # Compute the base (0) to link (i) adjoint 
matrix.\n        B_Xi_i = B_X_i[λ[i]] @ Adjoint.inverse(i_X_λi[i])\n        B_X_i = B_X_i.at[i].set(B_Xi_i)\n\n        # Compute the ii-th column of the B_S_BL(s) matrix.\n        B_Sii_BL = B_Xi_i @ S[i]\n        J = J.at[0:6, 6 + ii].set(B_Sii_BL.squeeze())\n\n        return (B_X_i, J), None\n\n    (B_X_i, J), _ = (\n        jax.lax.scan(\n            f=compute_full_jacobian,\n            init=compute_full_jacobian_carry,\n            xs=np.arange(start=1, stop=model.number_of_links()),\n        )\n        if model.number_of_links() > 1\n        else [(B_X_i, J), None]\n    )\n\n    # Convert adjoints to SE(3) transforms.\n    # Returning them here prevents calling FK in case the output representation\n    # of the Jacobian needs to be changed.\n    B_H_L = jax.vmap(Adjoint.to_transform)(B_X_i)\n\n    # Adjust shape of doubly-left free-floating full Jacobian.\n    B_J_full_WL_B = J.squeeze().astype(float)\n\n    return B_J_full_WL_B, B_H_L\n\n\ndef jacobian_derivative_full_doubly_left(\n    model: js.model.JaxSimModel,\n    *,\n    joint_positions: jtp.VectorLike,\n    joint_velocities: jtp.VectorLike,\n) -> tuple[jtp.Matrix, jtp.Array]:\n    r\"\"\"\n    Compute the derivative of the doubly-left full free-floating Jacobian of a model.\n\n    The derivative of the full Jacobian is a 6x(6+n) matrix with all the columns filled.\n    It is useful to run the algorithm once, and then extract the link Jacobian\n    derivative by filtering the columns of the full Jacobian using the support\n    parent array :math:`\\kappa(i)` of the link.\n\n    Args:\n        model: The model to consider.\n        joint_positions: The positions of the joints.\n        joint_velocities: The velocities of the joints.\n\n    Returns:\n        The derivative of the doubly-left full free-floating Jacobian of a model.\n    \"\"\"\n\n    _, _, s, _, ṡ, _, _, _, _, _ = utils.process_inputs(\n        model=model, joint_positions=joint_positions, joint_velocities=joint_velocities\n    )\n\n    
# Get the parent array λ(i).\n    # Note: λ(0) must not be used, it's initialized to -1.\n    λ = model.kin_dyn_parameters.parent_array\n\n    # Compute the parent-to-child adjoints of the joints.\n    # These transforms define the relative kinematics of the entire model, including\n    # the base transform for both floating-base and fixed-base models.\n    i_X_λi = model.kin_dyn_parameters.joint_transforms(\n        joint_positions=s, base_transform=jnp.eye(4)\n    )\n\n    # Extract the joint motion subspaces.\n    S = model.kin_dyn_parameters.motion_subspaces\n\n    # Allocate the buffer of 6D transform base -> link.\n    B_X_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))\n    B_X_i = B_X_i.at[0].set(jnp.eye(6))\n\n    # Allocate the buffer of 6D transform derivatives base -> link.\n    B_Ẋ_i = jnp.zeros(shape=(model.number_of_links(), 6, 6))\n\n    # Allocate the buffer of the 6D link velocity in body-fixed representation.\n    B_v_Bi = jnp.zeros(shape=(model.number_of_links(), 6))\n\n    # Helper to compute the time derivative of the adjoint matrix.\n    def A_Ẋ_B(A_X_B: jtp.Matrix, B_v_AB: jtp.Vector) -> jtp.Matrix:\n        return A_X_B @ Cross.vx(B_v_AB).squeeze()\n\n    # ============================================\n    # Compute doubly-left full Jacobian derivative\n    # ============================================\n\n    # Allocate the Jacobian matrix.\n    J̇ = jnp.zeros(shape=(6, 6 + model.dofs()))\n\n    ComputeFullJacobianDerivativeCarry = tuple[\n        jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix\n    ]\n\n    compute_full_jacobian_derivative_carry: ComputeFullJacobianDerivativeCarry = (\n        B_v_Bi,\n        B_X_i,\n        B_Ẋ_i,\n        J̇,\n    )\n\n    def compute_full_jacobian_derivative(\n        carry: ComputeFullJacobianDerivativeCarry, i: jtp.Int\n    ) -> tuple[ComputeFullJacobianDerivativeCarry, None]:\n\n        ii = i - 1\n        B_v_Bi, B_X_i, B_Ẋ_i, J̇ = carry\n\n        # Compute the base (0) to link (i) 
adjoint matrix.\n        B_Xi_i = B_X_i[λ[i]] @ Adjoint.inverse(i_X_λi[i])\n        B_X_i = B_X_i.at[i].set(B_Xi_i)\n\n        # Compute the body-fixed velocity of the link.\n        B_vi_Bi = B_v_Bi[λ[i]] + B_X_i[i] @ S[i].squeeze() * ṡ[ii]\n        B_v_Bi = B_v_Bi.at[i].set(B_vi_Bi)\n\n        # Compute the base (0) to link (i) adjoint matrix derivative.\n        i_Xi_B = Adjoint.inverse(B_Xi_i)\n        B_Ẋi_i = A_Ẋ_B(A_X_B=B_Xi_i, B_v_AB=i_Xi_B @ B_vi_Bi)\n        B_Ẋ_i = B_Ẋ_i.at[i].set(B_Ẋi_i)\n\n        # Compute the ii-th column of the B_Ṡ_BL(s) matrix.\n        B_Ṡii_BL = B_Ẋ_i[i] @ S[i]\n        J̇ = J̇.at[0:6, 6 + ii].set(B_Ṡii_BL.squeeze())\n\n        return (B_v_Bi, B_X_i, B_Ẋ_i, J̇), None\n\n    (_, B_X_i, B_Ẋ_i, J̇), _ = (\n        jax.lax.scan(\n            f=compute_full_jacobian_derivative,\n            init=compute_full_jacobian_derivative_carry,\n            xs=np.arange(start=1, stop=model.number_of_links()),\n        )\n        if model.number_of_links() > 1\n        else [(_, B_X_i, B_Ẋ_i, J̇), None]\n    )\n\n    # Convert adjoints to SE(3) transforms.\n    # Returning them here prevents calling FK in case the output representation\n    # of the Jacobian needs to be changed.\n    B_H_L = jax.vmap(Adjoint.to_transform)(B_X_i)\n\n    # Adjust shape of doubly-left free-floating full Jacobian derivative.\n    B_J̇_full_WL_B = J̇.squeeze().astype(float)\n\n    return B_J̇_full_WL_B, B_H_L\n"
  },
  {
    "path": "src/jaxsim/rbda/kinematic_constraints.py",
    "content": "from __future__ import annotations\n\nimport jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr\nfrom jaxsim.api.kin_dyn_parameters import ConstraintMap\nfrom jaxsim.math.adjoint import Adjoint\nfrom jaxsim.math.rotation import Rotation\nfrom jaxsim.math.transform import Transform\n\n# Utility functions used for constraints computation. These functions duplicate part of the jaxsim.api.frame module for computational efficiency.\n# TODO: remove these functions when jaxsim.api.frame is optimized for batched computations.\n# See: https://github.com/gbionics/jaxsim/issues/451\n\n\ndef _compute_constraint_transforms_batched(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    constraints: ConstraintMap,\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the transformation matrices for kinematic constraints between pairs of frames.\n\n    Args:\n        model: The JaxSim model containing the robot description.\n        data: The model data containing current state information.\n        constraints: The constraint map containing frame indices and parent link information.\n\n    Returns:\n        A matrix with shape (n_constraints, 2, 4, 4) containing the transformation matrices\n        for each constraint pair. 
The second dimension contains [W_H_F1, W_H_F2] where\n        W_H_F1 and W_H_F2 are the world-to-frame transformation matrices.\n    \"\"\"\n    W_H_L = data._link_transforms\n\n    frame_idxs_1 = constraints.frame_idxs_1\n    frame_idxs_2 = constraints.frame_idxs_2\n\n    parent_link_idxs_1 = constraints.parent_link_idxs_1\n    parent_link_idxs_2 = constraints.parent_link_idxs_2\n\n    # Extract frame transforms\n    L_H_F1 = model.kin_dyn_parameters.frame_parameters.transform[\n        frame_idxs_1 - model.number_of_links()\n    ]\n    L_H_F2 = model.kin_dyn_parameters.frame_parameters.transform[\n        frame_idxs_2 - model.number_of_links()\n    ]\n\n    # Compute the homogeneous transformation matrices for the two frames\n    W_H_F1 = W_H_L[parent_link_idxs_1] @ L_H_F1\n    W_H_F2 = W_H_L[parent_link_idxs_2] @ L_H_F2\n\n    return jnp.stack([W_H_F1, W_H_F2], axis=1)\n\n\ndef _compute_constraint_jacobians_batched(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    constraints: ConstraintMap,\n    W_H_constraint_pairs: jtp.Matrix,\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the constraint Jacobian matrices for kinematic constraints in a batched manner.\n    Args:\n        model: The JaxSim model containing the robot description.\n        data: The model data containing current state information.\n        constraints: The constraint map containing frame indices and parent link information.\n        W_H_constraint_pairs: Transformation matrices for constraint frame pairs with shape\n                              (n_constraints, 2, 4, 4).\n\n    Returns:\n        A matrix with shape (n_constraints, 6, n_dofs) containing the constraint Jacobian\n        matrices.\n    \"\"\"\n\n    with data.switch_velocity_representation(VelRepr.Body):\n        # Doubly-left free-floating Jacobian.\n        L_J_WL_B = js.model.generalized_free_floating_jacobian(\n            model=model, data=data, output_vel_repr=VelRepr.Body\n        )\n\n        # Link 
transforms\n        W_H_L = data._link_transforms\n\n    def compute_frame_jacobian_mixed(L_J_WL, W_H_L, W_H_F, parent_link_index):\n        \"\"\"Compute the jacobian of a frame in mixed representation.\"\"\"\n        # Select the jacobian of the parent link\n        L_J_WL = L_J_WL[parent_link_index]\n\n        # Compute the jacobian of the frame in mixed representation\n        W_H_L = W_H_L[parent_link_index]\n        F_H_L = Transform.inverse(W_H_F) @ W_H_L\n        FW_H_F = W_H_F.at[0:3, 3].set(jnp.zeros(3))\n        FW_H_L = FW_H_F @ F_H_L\n        FW_X_L = Adjoint.from_transform(transform=FW_H_L)\n        FW_J_WL = FW_X_L @ L_J_WL\n        O_J_WL_I = FW_J_WL\n\n        return O_J_WL_I\n\n    def compute_constraint_jacobian(L_J_WL, W_H_F, constraint):\n        \"\"\"Compute the constraint jacobian for a single constraint pair.\"\"\"\n\n        J_WF1 = compute_frame_jacobian_mixed(\n            L_J_WL, W_H_L, W_H_F[0], constraint.parent_link_idxs_1\n        )\n        J_WF2 = compute_frame_jacobian_mixed(\n            L_J_WL, W_H_L, W_H_F[1], constraint.parent_link_idxs_2\n        )\n\n        return J_WF1 - J_WF2\n\n    # Vectorize the computation of constraint Jacobians\n    constraint_jacobians = jax.vmap(compute_constraint_jacobian, in_axes=(None, 0, 0))(\n        L_J_WL_B, W_H_constraint_pairs, constraints\n    )\n\n    return constraint_jacobians\n\n\ndef _compute_constraint_baumgarte_term(\n    J_constr: jtp.Matrix,\n    nu: jtp.Vector,\n    W_H_F_constr: jtp.Matrix,\n    constraint: ConstraintMap,\n) -> jtp.Vector:\n    \"\"\"\n    Compute the Baumgarte stabilization term for kinematic constraints.\n\n    The Baumgarte stabilization method is used to stabilize constraint violations\n    by adding proportional and derivative terms to the constraint equation. 
This\n    helps prevent constraint drift and improves numerical stability.\n\n    Args:\n        J_constr: The constraint Jacobian matrix with shape (6, n_dofs).\n        nu: The generalized velocity vector with shape (n_dofs,).\n        W_H_F_constr: Array containing the homogeneous transformation matrices\n                      of two frames [W_H_F1, W_H_F2] with respect to the world frame,\n                      with shape (2, 4, 4).\n        constraint: The constraint object containing stabilization gains K_P and K_D.\n\n    Returns:\n        The computed Baumgarte stabilization term.\n    \"\"\"\n    W_H_F1, W_H_F2 = W_H_F_constr\n\n    W_p_F1 = W_H_F1[0:3, 3]\n    W_p_F2 = W_H_F2[0:3, 3]\n\n    W_R_F1 = W_H_F1[0:3, 0:3]\n    W_R_F2 = W_H_F2[0:3, 0:3]\n\n    K_P = constraint.K_P\n    K_D = constraint.K_D\n\n    vel_error = J_constr @ nu\n    position_error = W_p_F1 - W_p_F2\n    R_error = W_R_F2.T @ W_R_F1\n    orientation_error = Rotation.log_vee(R_error)\n\n    baumgarte_term = (\n        K_P * jnp.concatenate([position_error, orientation_error]) + K_D * vel_error\n    )\n\n    return baumgarte_term\n\n\n@jax.jit\n@js.common.named_scope\ndef compute_constraint_wrenches(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    *,\n    joint_force_references: jtp.VectorLike | None = None,\n    link_forces_inertial: jtp.MatrixLike | None = None,\n    regularization: jtp.Float = 1e-3,\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the constraint wrenches for kinematic constraints.\n\n    This function solves the constraint forces needed to satisfy kinematic constraints\n    between pairs of frames. It uses the Baumgarte stabilization method and computes\n    the constraint wrenches in inertial representation.\n\n    Args:\n        model: The JaxSim model.\n        data: The model data.\n        joint_force_references: Optional joint force/torque references to apply. 
If None,\n                               zero forces are used.\n        link_forces_inertial: Optional link forces applied in inertial representation.\n                             If None, zero forces are used.\n        regularization: Regularization parameter for the constraint solver to improve\n                       numerical stability. Default is 1e-3.\n\n    Returns:\n        Array with shape (n_constraints, 2, 6) containing constraint wrench pairs\n        in inertial representation. Each constraint produces two equal and opposite\n        wrenches applied to the constrained frames.\n    \"\"\"\n\n    # Retrieve the kinematic constraints, if any.\n    kin_constraints = model.kin_dyn_parameters.constraints\n\n    n_kin_constraints = (\n        6 * kin_constraints.frame_idxs_1.shape[0]\n        if kin_constraints is not None and kin_constraints.frame_idxs_1.shape[0] > 0\n        else 0\n    )\n\n    # Return empty results if no constraints exist\n    if n_kin_constraints == 0:\n        return jnp.zeros((0, 2, 6))\n\n    # Build joint forces if not provided\n    τ_references = (\n        jnp.asarray(joint_force_references, dtype=float)\n        if joint_force_references is not None\n        else jnp.zeros_like(data.joint_positions)\n    )\n\n    # Build link forces if not provided\n    W_f_L = (\n        jnp.atleast_2d(jnp.array(link_forces_inertial).squeeze())\n        if link_forces_inertial is not None\n        else jnp.zeros((model.number_of_links(), 6))\n    ).astype(float)\n\n    # Create references object for handling different velocity representations\n    references = js.references.JaxSimModelReferences.build(\n        model=model,\n        joint_force_references=τ_references,\n        link_forces=W_f_L,\n        velocity_representation=VelRepr.Inertial,\n    )\n\n    with (\n        data.switch_velocity_representation(VelRepr.Mixed),\n        references.switch_velocity_representation(VelRepr.Mixed),\n    ):\n        BW_ν = 
data.generalized_velocity\n\n        # Compute free acceleration without constraints\n        BW_ν̇_free = jnp.hstack(\n            js.model.forward_dynamics_aba(\n                model=model,\n                data=data,\n                link_forces=references.link_forces(model=model, data=data),\n                joint_forces=references.joint_force_references(model=model),\n            )\n        )\n\n        # Compute mass matrix\n        M_inv = js.model.free_floating_mass_matrix_inverse(model=model, data=data)\n\n        W_H_constr_pairs = _compute_constraint_transforms_batched(\n            model=model,\n            data=data,\n            constraints=kin_constraints,\n        )\n\n        # Compute constraint jacobians\n        J_constr = _compute_constraint_jacobians_batched(\n            model=model,\n            data=data,\n            constraints=kin_constraints,\n            W_H_constraint_pairs=W_H_constr_pairs,\n        )\n\n        # Compute Baumgarte stabilization term\n        constr_baumgarte_term = jnp.ravel(\n            jax.vmap(\n                _compute_constraint_baumgarte_term,\n                in_axes=(0, None, 0, 0),\n            )(\n                J_constr,\n                BW_ν,\n                W_H_constr_pairs,\n                kin_constraints,\n            ),\n        )\n\n        # Stack constraint jacobians\n        J_constr = jnp.vstack(J_constr)\n\n        # Compute Delassus matrix for constraints\n        G_constraints = J_constr @ M_inv @ J_constr.T\n\n        # Compute constraint acceleration\n        # TODO: add J̇_constr with efficient computation\n        CW_al_free_constr = J_constr @ BW_ν̇_free\n\n        # Setup constraint optimization problem\n        constraint_regularization = regularization * jnp.ones(n_kin_constraints)\n        R = jnp.diag(constraint_regularization)\n        A = G_constraints + R\n        b = CW_al_free_constr + constr_baumgarte_term\n\n        # Solve for constraint forces\n        
kin_constr_wrench_mixed = jnp.linalg.solve(A, -b).reshape(-1, 6)\n\n    def transform_wrenches_to_inertial(wrench, transform_pair):\n        \"\"\"\n        Transform wrench pairs in inertial representation.\n\n        Args:\n            wrench: Wrench vector with shape (6,).\n            transform_pair: Pair of transformation matrices [W_H_F1, W_H_F2]\n\n        Returns:\n            Stack of transformed wrenches with shape (2, 6).\n        \"\"\"\n        W_H_F1, W_H_F2 = transform_pair[0], transform_pair[1]\n        wrench_F1 = wrench\n        wrench_F2 = -wrench\n\n        # Create wrench pair directly\n        # Transform both at once\n        wrench_F1_inertial = (\n            ModelDataWithVelocityRepresentation.other_representation_to_inertial(\n                array=wrench_F1,\n                transform=W_H_F1,\n                other_representation=VelRepr.Mixed,\n                is_force=True,\n            )\n        )\n        wrench_F2_inertial = (\n            ModelDataWithVelocityRepresentation.other_representation_to_inertial(\n                array=wrench_F2,\n                transform=W_H_F2,\n                other_representation=VelRepr.Mixed,\n                is_force=True,\n            )\n        )\n\n        return jnp.stack([wrench_F1_inertial, wrench_F2_inertial])\n\n    kin_constr_wrench_pairs_inertial = jax.vmap(transform_wrenches_to_inertial)(\n        kin_constr_wrench_mixed, W_H_constr_pairs\n    )\n\n    return kin_constr_wrench_pairs_inertial\n"
  },
  {
    "path": "src/jaxsim/rbda/mass_inverse.py",
    "content": "import jax\nimport jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\n\n\ndef mass_inverse(\n    model: js.model.JaxSimModel,\n    *,\n    joint_transforms: jtp.MatrixLike,\n) -> jtp.Matrix:\n    \"\"\"\n    Compute the inverse of the mass matrix using an ABA-like algorithm.\n    The implementation follows the approach described in https://laas.hal.science/hal-01790934v2.\n\n    Args:\n        model: The model to consider.\n        joint_transforms: The parent-to-child transforms of the joints.\n\n    Returns:\n        The inverse of the mass matrix.\n    \"\"\"\n\n    # Get the 6D spatial inertia matrices of all links.\n    I_A = js.model.link_spatial_inertia_matrices(model=model)\n\n    # Get the parent array λ(i).\n    #   λ[0] ~ -1 (world)\n    #   λ[i] = parent link index for link i.\n    λ = model.kin_dyn_parameters.parent_array\n\n    # Extract the parent-to-child adjoints of the joints.\n    # These transforms define the relative kinematics of the entire model, including\n    # the base transform for both floating-base and fixed-base models.\n    i_X_λi = jnp.asarray(joint_transforms)\n\n    # Extract the joint motion subspaces.\n    S = model.kin_dyn_parameters.motion_subspaces\n\n    NB = model.number_of_links()\n    N = model.number_of_joints()\n\n    # Total generalized velocities: 6 base + N.\n    nv = N + 6\n\n    # Allocate buffers.\n    F = jnp.zeros((NB, 6, nv), dtype=float)\n    P = jnp.zeros((NB, 6, nv), dtype=float)\n    U = jnp.zeros((NB, 6), dtype=float)\n    D = jnp.zeros((NB,), dtype=float)\n\n    # Pre-allocate mass matrix inverse\n    M_inv = jnp.zeros((nv, nv), dtype=float)\n\n    # Pre-compute indices.\n    idx_fwd = jnp.arange(1, NB)\n    idx_rev = jnp.arange(NB - 1, 0, -1)\n\n    # =============\n    # Backward Pass\n    # =============\n\n    BackwardPassCarry = tuple[\n        jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix\n    ]\n    backward_pass_carry: BackwardPassCarry = (I_A, F, 
U, D, M_inv)\n\n    def loop_backward_pass(\n        carry: BackwardPassCarry, i: jtp.Int\n    ) -> tuple[BackwardPassCarry, None]:\n        I_A, F, U, D, M_inv = carry\n\n        Si = jnp.squeeze(S[i], axis=-1)\n        Fi = F[i]\n        Xi = i_X_λi[i]\n        parent = λ[i]\n\n        Ui = I_A[i] @ Si\n        Di = jnp.dot(Si, Ui)\n\n        U = U.at[i].set(Ui)\n        D = D.at[i].set(Di)\n\n        # Row index in ν for joint i: 6 + (i - 1)\n        r = 6 + (i - 1)\n\n        Minv_row = M_inv[r]\n\n        # Diagonal element\n        Minv_row = Minv_row.at[r].add(1.0 / Di)\n\n        # Off-diagonals: Minv[r,:] -= (1/Di) * Sᵢᵀ Fᵢ\n        sTFi = jnp.einsum(\"s,sn->n\", Si, Fi)\n        Minv_row = Minv_row - sTFi / Di\n\n        M_inv = M_inv.at[r].set(Minv_row)\n\n        # Propagate to parent if any (parent >= 0)\n        def propagate(IA_F):\n            I_A_, F_ = IA_F\n\n            Ui_col = Ui[:, None]\n\n            # F_a_i = F_i + U_i * Minv[r,:]\n            Fa_i = Fi + Ui_col @ Minv_row[None, :]\n\n            # F_parent += Xᵢᵀ F_a_i\n            F_parent_new = F_[parent] + Xi.T @ Fa_i\n            F_ = F_.at[parent].set(F_parent_new)\n\n            # I_a_i = IAi - U_i D_i^{-1} U_iᵀ\n            Ia_i = I_A[i] - jnp.outer(Ui, Ui) / Di\n\n            # I_A[parent] += Xᵢᵀ I_a_i Xᵢ\n            I_parent_new = I_A_[parent] + Xi.T @ Ia_i @ Xi\n            I_A_ = I_A_.at[parent].set(I_parent_new)\n\n            return I_A_, F_\n\n        I_A, F = jax.lax.cond(\n            parent >= 0,\n            propagate,\n            lambda IA_F: IA_F,\n            (I_A, F),\n        )\n\n        return (I_A, F, U, D, M_inv), None\n\n    (I_A, F, U, D, M_inv), _ = jax.lax.scan(\n        loop_backward_pass, backward_pass_carry, idx_rev\n    )\n\n    S0 = jnp.eye(6, dtype=float)\n    U0 = I_A[0] @ S0\n    D0 = S0.T @ U0\n    D0_inv = jnp.linalg.inv(D0)\n\n    # Base rows 0..5 in ν\n    base_rows = slice(0, 6)\n\n    # Diagonal base block\n    M_inv = M_inv.at[base_rows, 
base_rows].add(D0_inv)\n\n    # Off-diagonal base contribution: M_inv[base,:] -= D0^{-T} F[0]\n    term0 = D0_inv.T @ F[0]\n    M_inv = M_inv.at[base_rows, :].add(-term0)\n\n    # ============\n    # Forward Pass\n    # ============\n\n    # Initialize P_0 = S0 * Minv[base,:] = I * Minv[base,:]\n    Minv_base = M_inv[base_rows, :]\n    P = P.at[0].set(Minv_base)\n\n    ForwardPassCarry = tuple[jtp.Matrix, jtp.Matrix]\n    forward_pass_carry: ForwardPassCarry = (M_inv, P)\n\n    def loop_forward_pass(\n        carry: ForwardPassCarry, i: jtp.Int\n    ) -> tuple[ForwardPassCarry, None]:\n        M_inv, P = carry\n\n        Si = jnp.squeeze(S[i], axis=-1)\n        Ui = U[i]\n        Di = D[i]\n        Xi = i_X_λi[i]\n        parent = λ[i]\n\n        P_parent = jax.lax.cond(\n            parent >= 0,\n            lambda P_: P_[parent],\n            lambda P_: jnp.zeros_like(P_[i]),\n            P,\n        )\n\n        # Row index in ν for joint i\n        r = 6 + (i - 1)\n\n        # Row update: M_inv[r,:] -= D_i^{-1} U_iᵀ Xᵢ P_parent\n        def update_row(Minv_):\n            X_P = Xi @ P_parent\n            UiT_XP = jnp.einsum(\"s,sn->n\", Ui, X_P)\n            Minv_row = Minv_[r, :] - UiT_XP / Di\n            return Minv_.at[r, :].set(Minv_row)\n\n        M_inv = jax.lax.cond(\n            parent >= 0,\n            update_row,\n            lambda Minv_: Minv_,\n            M_inv,\n        )\n\n        Minv_row = M_inv[r, :]\n\n        # P_i = S_i Minv[r,:] + Xᵢ P_parent\n        Pi = jnp.expand_dims(Si, 1) @ jnp.expand_dims(Minv_row, 0)\n        Pi = Pi + Xi @ P_parent\n\n        P = P.at[i].set(Pi)\n\n        return (M_inv, P), None\n\n    (M_inv, P), _ = jax.lax.scan(loop_forward_pass, forward_pass_carry, idx_fwd)\n\n    # Symmetrize numerically\n    M_inv = 0.5 * (M_inv + M_inv.T)\n\n    return M_inv\n"
  },
  {
    "path": "src/jaxsim/rbda/rnea.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport jaxlie\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim.math import STANDARD_GRAVITY, Adjoint, Cross\n\nfrom . import utils\n\n\ndef rnea(\n    model: js.model.JaxSimModel,\n    *,\n    base_position: jtp.Vector,\n    base_quaternion: jtp.Vector,\n    joint_positions: jtp.Vector,\n    base_linear_velocity: jtp.Vector,\n    base_angular_velocity: jtp.Vector,\n    joint_velocities: jtp.Vector,\n    base_linear_acceleration: jtp.Vector | None = None,\n    base_angular_acceleration: jtp.Vector | None = None,\n    joint_accelerations: jtp.Vector | None = None,\n    joint_transforms: jtp.MatrixLike,\n    link_forces: jtp.Matrix | None = None,\n    standard_gravity: jtp.FloatLike = STANDARD_GRAVITY,\n) -> tuple[jtp.Vector, jtp.Vector]:\n    \"\"\"\n    Compute inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA).\n\n    Args:\n        model: The model to consider.\n        base_position: The position of the base link.\n        base_quaternion: The quaternion of the base link.\n        joint_positions: The positions of the joints.\n        base_linear_velocity:\n            The linear velocity of the base link in inertial-fixed representation.\n        base_angular_velocity:\n            The angular velocity of the base link in inertial-fixed representation.\n        joint_velocities: The velocities of the joints.\n        base_linear_acceleration:\n            The linear acceleration of the base link in inertial-fixed representation.\n        base_angular_acceleration:\n            The angular acceleration of the base link in inertial-fixed representation.\n        joint_accelerations: The accelerations of the joints.\n        joint_transforms: The parent-to-child transforms of the joints.\n        link_forces:\n            The forces applied to the links expressed in the world frame.\n        standard_gravity: The standard gravity constant.\n\n    Returns:\n        A tuple 
containing the 6D force applied to the base link expressed in the\n        world frame and the joint forces that, when applied respectively to the base\n        link and joints, produce the given base and joint accelerations.\n    \"\"\"\n\n    W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, _, W_f, W_g = utils.process_inputs(\n        model=model,\n        base_position=base_position,\n        base_quaternion=base_quaternion,\n        joint_positions=joint_positions,\n        base_linear_velocity=base_linear_velocity,\n        base_angular_velocity=base_angular_velocity,\n        joint_velocities=joint_velocities,\n        base_linear_acceleration=base_linear_acceleration,\n        base_angular_acceleration=base_angular_acceleration,\n        joint_accelerations=joint_accelerations,\n        link_forces=link_forces,\n        standard_gravity=standard_gravity,\n    )\n\n    W_g = jnp.atleast_2d(W_g).T\n    W_v_WB = jnp.atleast_2d(W_v_WB).T\n    W_v̇_WB = jnp.atleast_2d(W_v̇_WB).T\n\n    # Get the 6D spatial inertia matrices of all links.\n    M = js.model.link_spatial_inertia_matrices(model=model)\n\n    # Get the parent array λ(i).\n    # Note: λ(0) must not be used, it's initialized to -1.\n    λ = model.kin_dyn_parameters.parent_array\n\n    # Compute the base transform.\n    W_H_B = jaxlie.SE3.from_rotation_and_translation(\n        rotation=jaxlie.SO3(wxyz=W_Q_B),\n        translation=W_p_B,\n    )\n\n    # Compute 6D transforms of the base velocity.\n    W_X_B = W_H_B.adjoint()\n    B_X_W = W_H_B.inverse().adjoint()\n\n    # Extract the parent-to-child adjoints of the joints.\n    # These transforms define the relative kinematics of the entire model, including\n    # the base transform for both floating-base and fixed-base models.\n    # Ensure cached transforms are JAX arrays so they work with traced indices.\n    i_X_λi = jnp.asarray(joint_transforms)\n\n    # Extract the joint motion subspaces.\n    S = model.kin_dyn_parameters.motion_subspaces\n\n    # Allocate 
buffers.\n    v = jnp.zeros(shape=(model.number_of_links(), 6, 1))\n    a = jnp.zeros(shape=(model.number_of_links(), 6, 1))\n    f = jnp.zeros(shape=(model.number_of_links(), 6, 1))\n\n    # Allocate the buffer of transforms link -> base.\n    i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6))\n    i_X_0 = i_X_0.at[0].set(jnp.eye(6))\n\n    # Initialize the acceleration of the base link.\n    a_0 = -B_X_W @ W_g\n    a = a.at[0].set(a_0)\n\n    if model.floating_base():\n\n        # Base velocity v₀ in body-fixed representation.\n        v_0 = B_X_W @ W_v_WB\n        v = v.at[0].set(v_0)\n\n        # Base acceleration a₀ in body-fixed representation w/o gravity.\n        a_0 = B_X_W @ (W_v̇_WB - W_g)\n        a = a.at[0].set(a_0)\n\n        # Force applied to the base link that produce the base acceleration w/o gravity.\n        f_0 = (\n            M[0] @ a[0]\n            + Cross.vx_star(v[0]) @ M[0] @ v[0]\n            - W_X_B.T @ jnp.vstack(W_f[0])\n        )\n        f = f.at[0].set(f_0)\n\n    # ======\n    # Pass 1\n    # ======\n\n    ForwardPassCarry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]\n    forward_pass_carry: ForwardPassCarry = (v, a, i_X_0, f)\n\n    def forward_pass(\n        carry: ForwardPassCarry, i: jtp.Int\n    ) -> tuple[ForwardPassCarry, None]:\n\n        ii = i - 1\n        v, a, i_X_0, f = carry\n\n        # Project the joint velocity into its motion subspace.\n        vJ = S[i] * ṡ[ii]\n\n        # Propagate the link velocity.\n        v_i = i_X_λi[i] @ v[λ[i]] + vJ\n        v = v.at[i].set(v_i)\n\n        # Propagate the link acceleration.\n        a_i = i_X_λi[i] @ a[λ[i]] + S[i] * s̈[ii] + Cross.vx(v[i]) @ vJ\n        a = a.at[i].set(a_i)\n\n        # Compute the link-to-base transform.\n        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]\n        i_X_0 = i_X_0.at[i].set(i_X_0_i)\n\n        # Compute link-to-world transform for the 6D force.\n        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T\n\n        # Compute the 
force acting on the link.\n        f_i = (\n            M[i] @ a[i]\n            + Cross.vx_star(v[i]) @ M[i] @ v[i]\n            - i_Xf_W @ jnp.vstack(W_f[i])\n        )\n        f = f.at[i].set(f_i)\n\n        return (v, a, i_X_0, f), None\n\n    (v, a, i_X_0, f), _ = (\n        jax.lax.scan(\n            f=forward_pass,\n            init=forward_pass_carry,\n            xs=jnp.arange(start=1, stop=model.number_of_links()),\n        )\n        if model.number_of_links() > 1\n        else [(v, a, i_X_0, f), None]\n    )\n\n    # ======\n    # Pass 2\n    # ======\n\n    τ = jnp.zeros_like(s)\n\n    BackwardPassCarry = tuple[jtp.Vector, jtp.Matrix]\n    backward_pass_carry: BackwardPassCarry = (τ, f)\n\n    def backward_pass(\n        carry: BackwardPassCarry, i: jtp.Int\n    ) -> tuple[BackwardPassCarry, None]:\n\n        ii = i - 1\n        τ, f = carry\n\n        # Project the 6D force to the DoF of the joint.\n        τ_i = S[i].T @ f[i]\n        τ = τ.at[ii].set(τ_i.squeeze())\n\n        # Propagate the force to the parent link.\n        def update_f(f: jtp.Matrix) -> jtp.Matrix:\n\n            f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]\n            f = f.at[λ[i]].set(f_λi)\n\n            return f\n\n        f = jax.lax.cond(\n            pred=jnp.logical_or(λ[i] != 0, model.floating_base()),\n            true_fun=update_f,\n            false_fun=lambda f: f,\n            operand=f,\n        )\n\n        return (τ, f), None\n\n    (τ, f), _ = (\n        jax.lax.scan(\n            f=backward_pass,\n            init=backward_pass_carry,\n            xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),\n        )\n        if model.number_of_links() > 1\n        else [(τ, f), None]\n    )\n\n    # ==============\n    # Adjust outputs\n    # ==============\n\n    # Express the base 6D force in the world frame.\n    W_f0 = B_X_W.T @ f[0]\n\n    return W_f0.squeeze(), jnp.atleast_1d(τ.squeeze())\n"
  },
  {
    "path": "src/jaxsim/rbda/utils.py",
    "content": "import jax.numpy as jnp\n\nimport jaxsim.api as js\nimport jaxsim.typing as jtp\nfrom jaxsim import exceptions\nfrom jaxsim.math import STANDARD_GRAVITY\n\n\ndef process_inputs(\n    model: js.model.JaxSimModel,\n    *,\n    base_position: jtp.VectorLike | None = None,\n    base_quaternion: jtp.VectorLike | None = None,\n    joint_positions: jtp.VectorLike | None = None,\n    base_linear_velocity: jtp.VectorLike | None = None,\n    base_angular_velocity: jtp.VectorLike | None = None,\n    joint_velocities: jtp.VectorLike | None = None,\n    base_linear_acceleration: jtp.VectorLike | None = None,\n    base_angular_acceleration: jtp.VectorLike | None = None,\n    joint_accelerations: jtp.VectorLike | None = None,\n    joint_forces: jtp.VectorLike | None = None,\n    link_forces: jtp.MatrixLike | None = None,\n    standard_gravity: jtp.ScalarLike | None = None,\n) -> tuple[\n    jtp.Vector,\n    jtp.Vector,\n    jtp.Vector,\n    jtp.Vector,\n    jtp.Vector,\n    jtp.Vector,\n    jtp.Vector,\n    jtp.Vector,\n    jtp.Matrix,\n    jtp.Vector,\n]:\n    \"\"\"\n    Adjust the inputs to rigid-body dynamics algorithms.\n\n    Args:\n        model: The model to consider.\n        base_position: The position of the base link.\n        base_quaternion: The quaternion of the base link.\n        joint_positions: The positions of the joints.\n        base_linear_velocity: The linear velocity of the base link.\n        base_angular_velocity: The angular velocity of the base link.\n        joint_velocities: The velocities of the joints.\n        base_linear_acceleration: The linear acceleration of the base link.\n        base_angular_acceleration: The angular acceleration of the base link.\n        joint_accelerations: The accelerations of the joints.\n        joint_forces: The forces applied to the joints.\n        link_forces: The forces applied to the links.\n        standard_gravity: The standard gravity constant.\n\n    Returns:\n        The adjusted inputs.\n  
  \"\"\"\n\n    dofs = model.dofs()\n    nl = model.number_of_links()\n\n    # Floating-base position.\n    W_p_B = base_position\n    W_Q_B = base_quaternion\n    s = joint_positions\n\n    # Floating-base velocity in inertial-fixed representation.\n    W_vl_WB = base_linear_velocity\n    W_ω_WB = base_angular_velocity\n    ṡ = joint_velocities\n\n    # Floating-base acceleration in inertial-fixed representation.\n    W_v̇l_WB = base_linear_acceleration\n    W_ω̇_WB = base_angular_acceleration\n    s̈ = joint_accelerations\n\n    # System dynamics inputs.\n    f = link_forces\n    τ = joint_forces\n\n    # Fill missing data and adjust dimensions.\n    s = jnp.atleast_1d(s.squeeze()) if s is not None else jnp.zeros(dofs)\n    ṡ = jnp.atleast_1d(ṡ.squeeze()) if ṡ is not None else jnp.zeros(dofs)\n    s̈ = jnp.atleast_1d(s̈.squeeze()) if s̈ is not None else jnp.zeros(dofs)\n    τ = jnp.atleast_1d(τ.squeeze()) if τ is not None else jnp.zeros(dofs)\n    W_vl_WB = jnp.atleast_1d(W_vl_WB.squeeze()) if W_vl_WB is not None else jnp.zeros(3)\n    W_v̇l_WB = (\n        jnp.atleast_1d(W_v̇l_WB.squeeze()) if W_v̇l_WB is not None else jnp.zeros(3)\n    )\n    W_p_B = jnp.atleast_1d(W_p_B.squeeze()) if W_p_B is not None else jnp.zeros(3)\n    W_ω_WB = jnp.atleast_1d(W_ω_WB.squeeze()) if W_ω_WB is not None else jnp.zeros(3)\n    W_ω̇_WB = jnp.atleast_1d(W_ω̇_WB.squeeze()) if W_ω̇_WB is not None else jnp.zeros(3)\n    f = jnp.atleast_2d(f.squeeze()) if f is not None else jnp.zeros(shape=(nl, 6))\n    W_Q_B = (\n        jnp.atleast_1d(W_Q_B.squeeze())\n        if W_Q_B is not None\n        else jnp.array([1.0, 0, 0, 0])\n    )\n    standard_gravity = (\n        jnp.array(standard_gravity).squeeze()\n        if standard_gravity is not None\n        else STANDARD_GRAVITY\n    )\n\n    if s.shape != (dofs,):\n        raise ValueError(s.shape, dofs)\n\n    if ṡ.shape != (dofs,):\n        raise ValueError(ṡ.shape, dofs)\n\n    if s̈.shape != (dofs,):\n        raise 
ValueError(s̈.shape, dofs)\n\n    if τ.shape != (dofs,):\n        raise ValueError(τ.shape, dofs)\n\n    if W_p_B.shape != (3,):\n        raise ValueError(W_p_B.shape, (3,))\n\n    if W_vl_WB.shape != (3,):\n        raise ValueError(W_vl_WB.shape, (3,))\n\n    if W_ω_WB.shape != (3,):\n        raise ValueError(W_ω_WB.shape, (3,))\n\n    if W_v̇l_WB.shape != (3,):\n        raise ValueError(W_v̇l_WB.shape, (3,))\n\n    if W_ω̇_WB.shape != (3,):\n        raise ValueError(W_ω̇_WB.shape, (3,))\n\n    if f.shape != (nl, 6):\n        raise ValueError(f.shape, (nl, 6))\n\n    if W_Q_B.shape != (4,):\n        raise ValueError(W_Q_B.shape, (4,))\n\n    # Check that the quaternion does not contain NaN values.\n    exceptions.raise_value_error_if(\n        condition=jnp.isnan(W_Q_B).any(),\n        msg=\"A RBDA received a quaternion that contains NaN values.\",\n    )\n\n    # Check that the quaternion is unary since our RBDAs make this assumption in order\n    # to prevent introducing additional normalizations that would affect AD.\n    exceptions.raise_value_error_if(\n        condition=~jnp.allclose(W_Q_B.dot(W_Q_B), 1.0),\n        msg=\"A RBDA received a quaternion that is not normalized.\",\n    )\n\n    # Pack the 6D base velocity and acceleration.\n    W_v_WB = jnp.hstack([W_vl_WB, W_ω_WB])\n    W_v̇_WB = jnp.hstack([W_v̇l_WB, W_ω̇_WB])\n\n    # Create the 6D gravity acceleration.\n    W_g = jnp.array([0, 0, standard_gravity, 0, 0, 0])\n\n    return (\n        W_p_B.astype(float),\n        W_Q_B.astype(float),\n        s.astype(float),\n        W_v_WB.astype(float),\n        ṡ.astype(float),\n        W_v̇_WB.astype(float),\n        s̈.astype(float),\n        τ.astype(float),\n        f.astype(float),\n        W_g.astype(float),\n    )\n"
  },
  {
    "path": "src/jaxsim/terrain/__init__.py",
    "content": "from . import terrain\nfrom .terrain import FlatTerrain, PlaneTerrain, Terrain\n"
  },
  {
    "path": "src/jaxsim/terrain/terrain.py",
    "content": "from __future__ import annotations\n\nimport abc\nimport dataclasses\n\nimport jax.numpy as jnp\nimport jax_dataclasses\nimport numpy as np\n\nimport jaxsim.math\nimport jaxsim.typing as jtp\nfrom jaxsim import exceptions\n\n\nclass Terrain(abc.ABC):\n    \"\"\"\n    Base class for terrain models.\n\n    Attributes:\n        delta: The delta value used for numerical differentiation.\n    \"\"\"\n\n    delta = 0.010\n\n    @abc.abstractmethod\n    def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:\n        \"\"\"\n        Compute the height of the terrain at a specific (x, y) location.\n\n        Args:\n            x: The x-coordinate of the location.\n            y: The y-coordinate of the location.\n\n        Returns:\n            The height of the terrain at the specified location.\n        \"\"\"\n\n        pass\n\n    def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:\n        \"\"\"\n        Compute the normal vector of the terrain at a specific (x, y) location.\n\n        Args:\n            x: The x-coordinate of the location.\n            y: The y-coordinate of the location.\n\n        Returns:\n            The normal vector of the terrain surface at the specified location.\n        \"\"\"\n\n        # https://stackoverflow.com/a/5282364\n        h_xp = self.height(x=x + self.delta, y=y)\n        h_xm = self.height(x=x - self.delta, y=y)\n        h_yp = self.height(x=x, y=y + self.delta)\n        h_ym = self.height(x=x, y=y - self.delta)\n\n        n = jnp.array(\n            [(h_xm - h_xp) / (2 * self.delta), (h_ym - h_yp) / (2 * self.delta), 1.0]\n        )\n\n        return n / jaxsim.math.safe_norm(n, axis=-1)\n\n\n@jax_dataclasses.pytree_dataclass\nclass FlatTerrain(Terrain):\n    \"\"\"\n    Represents a terrain model with a flat surface and a constant height.\n    \"\"\"\n\n    _height: float = dataclasses.field(default=0.0, kw_only=True)\n\n    @staticmethod\n    def build(height: jtp.FloatLike = 0.0) 
-> FlatTerrain:\n        \"\"\"\n        Create a FlatTerrain instance with a specified height.\n\n        Args:\n            height: The height of the flat terrain.\n\n        Returns:\n            FlatTerrain: A FlatTerrain instance.\n        \"\"\"\n\n        return FlatTerrain(_height=float(height))\n\n    def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:\n        \"\"\"\n        Compute the height of the terrain at a specific (x, y) location.\n\n        Args:\n            x: The x-coordinate of the location.\n            y: The y-coordinate of the location.\n\n        Returns:\n            The height of the terrain at the specified location.\n        \"\"\"\n\n        return jnp.array(self._height, dtype=float)\n\n    def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:\n        \"\"\"\n        Compute the normal vector of the terrain at a specific (x, y) location.\n\n        Args:\n            x: The x-coordinate of the location.\n            y: The y-coordinate of the location.\n\n        Returns:\n            The normal vector of the terrain surface at the specified location.\n        \"\"\"\n\n        return jnp.array([0.0, 0.0, 1.0], dtype=float)\n\n    def __hash__(self) -> int:\n\n        return hash(self._height)\n\n    def __eq__(self, other: FlatTerrain) -> bool:\n\n        if not isinstance(other, FlatTerrain):\n            return False\n\n        return self._height == other._height\n\n\n@jax_dataclasses.pytree_dataclass\nclass PlaneTerrain(FlatTerrain):\n    \"\"\"\n    Represents a terrain model with a flat surface defined by a normal vector.\n    \"\"\"\n\n    _normal: tuple[float, float, float] = jax_dataclasses.field(\n        default=(0.0, 0.0, 1.0), kw_only=True\n    )\n\n    @staticmethod\n    def build(height: jtp.FloatLike = 0.0, *, normal: jtp.VectorLike) -> PlaneTerrain:\n        \"\"\"\n        Create a PlaneTerrain instance with a specified plane normal vector.\n\n        Args:\n            normal: The 
normal vector of the terrain plane.\n            height: The height of the plane over the origin.\n\n        Returns:\n            PlaneTerrain: A PlaneTerrain instance.\n        \"\"\"\n\n        normal = jnp.array(normal, dtype=float)\n        height = jnp.array(height, dtype=float)\n\n        if normal.shape != (3,):\n            msg = \"Expected a 3D vector for the plane normal, got '{}'.\"\n            raise ValueError(msg.format(normal.shape))\n\n        # Make sure that the plane normal is a unit vector.\n        normal = normal / jnp.linalg.norm(normal)\n\n        return PlaneTerrain(\n            _height=height.item(),\n            _normal=tuple(normal.tolist()),\n        )\n\n    def normal(\n        self, x: jtp.FloatLike | None = None, y: jtp.FloatLike | None = None\n    ) -> jtp.Vector:\n        \"\"\"\n        Compute the normal vector of the terrain at a specific (x, y) location.\n\n        Args:\n            x: The x-coordinate of the location.\n            y: The y-coordinate of the location.\n\n        Returns:\n            The normal vector of the terrain surface at the specified location.\n        \"\"\"\n\n        return jnp.array(self._normal, dtype=float)\n\n    def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:\n        \"\"\"\n        Compute the height of the terrain at a specific (x, y) location on a plane.\n\n        Args:\n            x: The x-coordinate of the location.\n            y: The y-coordinate of the location.\n\n        Returns:\n            The height of the terrain at the specified location on the plane.\n        \"\"\"\n\n        # Equation of the plane:      A x + B y + C z + D = 0\n        # Normal vector coordinates:  (A, B, C)\n        # The height over the origin: -D/C\n\n        # Get the plane equation coefficients from the terrain normal.\n        A, B, C = self._normal\n\n        exceptions.raise_value_error_if(\n            condition=jnp.allclose(C, 0.0),\n            msg=\"The z component of the 
normal cannot be zero.\",\n        )\n\n        # Compute the final coefficient D considering the terrain height.\n        D = -C * self._height\n\n        # Invert the plane equation to get the height at the given (x, y) coordinates.\n        return jnp.array(-(A * x + B * y + D) / C).astype(float)\n\n    def __hash__(self) -> int:\n\n        from jaxsim.utils.wrappers import HashedNumpyArray\n\n        return hash(\n            (\n                hash(self._height),\n                HashedNumpyArray.hash_of_array(\n                    array=np.array(self._normal, dtype=float)\n                ),\n            )\n        )\n\n    def __eq__(self, other: PlaneTerrain) -> bool:\n\n        if not isinstance(other, PlaneTerrain):\n            return False\n\n        if not (\n            np.allclose(self._height, other._height)\n            and np.allclose(\n                np.array(self._normal, dtype=float),\n                np.array(other._normal, dtype=float),\n            )\n        ):\n            return False\n\n        return True\n"
  },
  {
    "path": "src/jaxsim/typing.py",
    "content": "from collections.abc import Hashable\nfrom typing import Any, TypeVar\n\nimport jax\n\n# =========\n# JAX types\n# =========\n\nArray = jax.Array\nScalar = Array\nVector = Array\nMatrix = Array\n\nInt = Scalar\nBool = Scalar\nFloat = Scalar\n\nPyTree: object = (\n    dict[Hashable, TypeVar(\"PyTree\")]\n    | list[TypeVar(\"PyTree\")]\n    | tuple[TypeVar(\"PyTree\")]\n    | jax.Array\n    | Any\n    | None\n)\n\n# =======================\n# Mixed JAX / NumPy types\n# =======================\n\nArrayLike = jax.typing.ArrayLike | tuple\nScalarLike = int | float | Scalar | ArrayLike\nVectorLike = Vector | ArrayLike | tuple\nMatrixLike = Matrix | ArrayLike\n\nIntLike = int | Int | jax.typing.ArrayLike\nBoolLike = bool | Bool | jax.typing.ArrayLike\nFloatLike = float | Float | jax.typing.ArrayLike\n"
  },
  {
    "path": "src/jaxsim/utils/__init__.py",
    "content": "from jax_dataclasses._copy_and_mutate import _Mutability as Mutability\n\nfrom .jaxsim_dataclass import JaxsimDataclass\nfrom .tracing import not_tracing, tracing\nfrom .wrappers import HashedNumpyArray, HashlessObject\n"
  },
  {
    "path": "src/jaxsim/utils/jaxsim_dataclass.py",
    "content": "import abc\nimport contextlib\nimport dataclasses\nimport functools\nfrom collections.abc import Callable, Iterator, Sequence\nfrom typing import Any, ClassVar\n\nimport jax.flatten_util\nimport jax_dataclasses\n\nimport jaxsim.typing as jtp\n\nfrom . import Mutability\n\ntry:\n    from typing import Self\nexcept ImportError:\n    from typing_extensions import Self\n\n\n@jax_dataclasses.pytree_dataclass\nclass JaxsimDataclass(abc.ABC):\n    \"\"\"Class extending `jax_dataclasses.pytree_dataclass` instances with utilities.\"\"\"\n\n    # This attribute is set by jax_dataclasses\n    __mutability__: ClassVar[Mutability] = Mutability.FROZEN\n\n    @contextlib.contextmanager\n    def editable(self: Self, validate: bool = True) -> Iterator[Self]:\n        \"\"\"\n        Context manager to operate on a mutable copy of the object.\n\n        Args:\n            validate: Whether to validate the output PyTree upon exiting the context.\n\n        Yields:\n            A mutable copy of the object.\n\n        Note:\n            This context manager is useful to operate on an r/w copy of a PyTree making\n            sure that the output object does not trigger JIT recompilations.\n        \"\"\"\n\n        mutability = (\n            Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION\n        )\n\n        with self.copy().mutable_context(mutability=mutability) as obj:\n            yield obj\n\n    @contextlib.contextmanager\n    def mutable_context(\n        self: Self,\n        mutability: Mutability = Mutability.MUTABLE,\n        restore_after_exception: bool = True,\n    ) -> Iterator[Self]:\n        \"\"\"\n        Context manager to temporarily change the mutability of the object.\n\n        Args:\n            mutability: The mutability to set.\n            restore_after_exception:\n                Whether to restore the original object in case of an exception\n                occurring within the context.\n\n        Yields:\n            
The object with the new mutability.\n\n        Note:\n            This context manager is useful to operate in place on a PyTree without\n            the need to make a copy while optionally keeping active the checks on\n            the PyTree structure, shapes, and dtypes.\n        \"\"\"\n\n        if restore_after_exception:\n            self_copy = self.copy()\n\n        original_mutability = self.mutability()\n\n        original_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self)\n        original_shapes = JaxsimDataclass.get_leaf_shapes(tree=self)\n        original_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self)\n        original_structure = jax.tree.structure(tree=self)\n\n        def restore_self() -> None:\n            self.set_mutability(mutability=Mutability.MUTABLE_NO_VALIDATION)\n            for f in dataclasses.fields(self_copy):\n                setattr(self, f.name, getattr(self_copy, f.name))\n\n        try:\n            self.set_mutability(mutability=mutability)\n            yield self\n\n            if mutability is not Mutability.MUTABLE_NO_VALIDATION:\n                new_structure = jax.tree.structure(tree=self)\n                if original_structure != new_structure:\n                    msg = \"Pytree structure has changed from {} to {}\"\n                    raise ValueError(msg.format(original_structure, new_structure))\n\n                new_shapes = JaxsimDataclass.get_leaf_shapes(tree=self)\n                if original_shapes != new_shapes:\n                    msg = \"Leaves shapes have changed from {} to {}\"\n                    raise ValueError(msg.format(original_shapes, new_shapes))\n\n                new_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self)\n                if original_dtypes != new_dtypes:\n                    msg = \"Leaves dtypes have changed from {} to {}\"\n                    raise ValueError(msg.format(original_dtypes, new_dtypes))\n\n                new_weak_types = 
JaxsimDataclass.get_leaf_weak_types(tree=self)\n                if original_weak_types != new_weak_types:\n                    msg = \"Leaves weak types have changed from {} to {}\"\n                    raise ValueError(msg.format(original_weak_types, new_weak_types))\n\n        except Exception as e:\n            if restore_after_exception:\n                restore_self()\n            self.set_mutability(original_mutability)\n            raise e\n\n        finally:\n            self.set_mutability(original_mutability)\n\n    @staticmethod\n    def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None]:\n        \"\"\"\n        Get the leaf shapes of a PyTree.\n\n        Args:\n            tree: The PyTree to consider.\n\n        Returns:\n            A tuple containing the leaf shapes of the PyTree or `None` if the leaf is\n            not a numpy-like array.\n        \"\"\"\n\n        return tuple(\n            map(\n                lambda leaf: getattr(leaf, \"shape\", None),\n                jax.tree.leaves(tree),\n            )\n        )\n\n    @staticmethod\n    def get_leaf_dtypes(tree: jtp.PyTree) -> tuple:\n        \"\"\"\n        Get the leaf dtypes of a PyTree.\n\n        Args:\n            tree: The PyTree to consider.\n\n        Returns:\n            A tuple containing the leaf dtypes of the PyTree or `None` if the leaf is\n            not a numpy-like array.\n        \"\"\"\n\n        return tuple(\n            map(\n                lambda leaf: getattr(leaf, \"dtype\", None),\n                jax.tree.leaves(tree),\n            )\n        )\n\n    @staticmethod\n    def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool, ...]:\n        \"\"\"\n        Get the leaf weak types of a PyTree.\n\n        Args:\n            tree: The PyTree to consider.\n\n        Returns:\n            A tuple marking whether each leaf is a JAX array with weak type, or `None`\n            if the leaf has no `weak_type` attribute.\n        \"\"\"\n\n        return tuple(\n            map(\n                lambda leaf: 
getattr(leaf, \"weak_type\", None),\n                jax.tree.leaves(tree),\n            )\n        )\n\n    @staticmethod\n    def check_compatibility(*trees: Sequence[Any]) -> None:\n        \"\"\"\n        Check whether the PyTrees are compatible in structure, shape, and dtype.\n\n        Args:\n            *trees: The PyTrees to compare.\n\n        Raises:\n            ValueError: If the PyTrees have incompatible structures, shapes, or dtypes.\n        \"\"\"\n\n        target_structure = jax.tree.structure(trees[0])\n\n        compatible_structure = functools.reduce(\n            lambda compatible, tree: compatible\n            and jax.tree.structure(tree) == target_structure,\n            trees[1:],\n            True,\n        )\n\n        if not compatible_structure:\n            raise ValueError(\n                f\"Pytrees have incompatible structures.\\n\"\n                f\"Original: {', '.join(map(str, [jax.tree.structure(tree) for tree in trees[1:]]))}\\n\"\n                f\"Target: {target_structure}\"\n            )\n\n        target_shapes = JaxsimDataclass.get_leaf_shapes(trees[0])\n\n        compatible_shapes = functools.reduce(\n            lambda compatible, tree: compatible\n            and JaxsimDataclass.get_leaf_shapes(tree) == target_shapes,\n            trees[1:],\n            True,\n        )\n\n        if not compatible_shapes:\n            raise ValueError(\"Pytrees have incompatible shapes.\")\n\n        target_dtypes = JaxsimDataclass.get_leaf_dtypes(trees[0])\n\n        compatible_dtypes = functools.reduce(\n            lambda compatible, tree: compatible\n            and JaxsimDataclass.get_leaf_dtypes(tree) == target_dtypes,\n            trees[1:],\n            True,\n        )\n\n        if not compatible_dtypes:\n            raise ValueError(\"Pytrees have incompatible dtypes.\")\n\n    def is_mutable(self, validate: bool = False) -> bool:\n        \"\"\"\n        Check whether the object is mutable.\n\n        Args:\n         
   validate: Additionally checks if the object also has validation enabled.\n\n        Returns:\n            True if the object is mutable, False otherwise.\n        \"\"\"\n\n        return (\n            self.__mutability__ is Mutability.MUTABLE\n            if validate\n            else self.__mutability__ is Mutability.MUTABLE_NO_VALIDATION\n        )\n\n    def mutability(self) -> Mutability:\n        \"\"\"\n        Get the mutability type of the object.\n\n        Returns:\n            The mutability type of the object.\n        \"\"\"\n\n        return self.__mutability__\n\n    def set_mutability(self, mutability: Mutability) -> None:\n        \"\"\"\n        Set the mutability of the object in-place.\n\n        Args:\n            mutability: The desired mutability type.\n        \"\"\"\n\n        jax_dataclasses._copy_and_mutate._mark_mutable(\n            self, mutable=mutability, visited=set()\n        )\n\n    def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self:\n        \"\"\"\n        Return a mutable reference of the object.\n\n        Args:\n            mutable: Whether to make the object mutable.\n            validate: Whether to enable validation on the object.\n\n        Returns:\n            A mutable reference of the object.\n        \"\"\"\n\n        if mutable:\n            mutability = (\n                Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION\n            )\n        else:\n            mutability = Mutability.FROZEN\n\n        self.set_mutability(mutability=mutability)\n        return self\n\n    def copy(self: Self) -> Self:\n        \"\"\"\n        Return a copy of the object.\n\n        Returns:\n            A copy of the object.\n        \"\"\"\n\n        # Make a copy calling tree_map.\n        obj = jax.tree.map(lambda leaf: leaf, self)\n\n        # Make sure that the copied object and all the copied leaves have the same\n        # mutability of the original object.\n        
obj.set_mutability(mutability=self.mutability())\n\n        return obj\n\n    def replace(self: Self, validate: bool = True, **kwargs) -> Self:\n        \"\"\"\n        Return a new object replacing in-place the specified fields with new values.\n\n        Args:\n            validate: Whether to validate that the new fields do not alter the PyTree.\n            **kwargs: The fields to replace.\n\n        Returns:\n            A reference of the object with the specified fields replaced.\n        \"\"\"\n\n        # Use the dataclasses replace method.\n        obj = dataclasses.replace(self, **kwargs)\n\n        if validate:\n            JaxsimDataclass.check_compatibility(self, obj)\n\n        # Make sure that all the new leaves have the same mutability of the object.\n        obj.set_mutability(mutability=self.mutability())\n\n        return obj\n\n    def flatten(self) -> jtp.Vector:\n        \"\"\"\n        Flatten the object into a 1D vector.\n\n        Returns:\n            A 1D vector containing the flattened object.\n        \"\"\"\n\n        return self.flatten_fn()(self)\n\n    @classmethod\n    def flatten_fn(cls: type[Self]) -> Callable[[Self], jtp.Vector]:\n        \"\"\"\n        Return a function to flatten the object into a 1D vector.\n\n        Returns:\n            A function to flatten the object into a 1D vector.\n        \"\"\"\n\n        return lambda pytree: jax.flatten_util.ravel_pytree(pytree)[0]\n\n    def unflatten_fn(self: Self) -> Callable[[jtp.Vector], Self]:\n        \"\"\"\n        Return a function to unflatten a 1D vector into the object.\n\n        Returns:\n            A function to unflatten a 1D vector into the object.\n\n        Notes:\n            Due to JAX internals, the function to unflatten a PyTree needs to be\n            created from an existing instance of the PyTree.\n        \"\"\"\n        return jax.flatten_util.ravel_pytree(self)[1]\n"
  },
  {
    "path": "src/jaxsim/utils/tracing.py",
    "content": "from typing import Any\n\nimport jax._src.core\nimport jax.flatten_util\nimport jax.interpreters.partial_eval\n\n\ndef tracing(var: Any) -> bool | jax.Array:\n    \"\"\"Return True if the variable is being traced by JAX, False otherwise.\"\"\"\n\n    return isinstance(\n        var, jax._src.core.Tracer | jax.interpreters.partial_eval.DynamicJaxprTracer\n    )\n\n\ndef not_tracing(var: Any) -> bool | jax.Array:\n    \"\"\"Return True if the variable is not being traced by JAX, False otherwise.\"\"\"\n\n    return True if tracing(var) is False else False\n"
  },
  {
    "path": "src/jaxsim/utils/wrappers.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nfrom collections.abc import Callable\nfrom typing import Generic, TypeVar\n\nimport jax\nimport jax_dataclasses\nimport numpy as np\nimport numpy.typing as npt\n\nT = TypeVar(\"T\")\n\n\n@dataclasses.dataclass\nclass HashlessObject(Generic[T]):\n    \"\"\"\n    A class that wraps an object and makes it hashless.\n\n    This is useful for creating particular JAX pytrees.\n    For example, to create a pytree with a static leaf that is ignored\n    by JAX when it compares two instances to trigger a JIT recompilation.\n    \"\"\"\n\n    obj: T\n\n    def get(self: HashlessObject[T]) -> T:\n        \"\"\"\n        Get the wrapped object.\n        \"\"\"\n        return self.obj\n\n    def __hash__(self) -> int:\n\n        return 0\n\n    def __eq__(self, other: HashlessObject[T]) -> bool:\n\n        if not isinstance(other, HashlessObject) and isinstance(\n            other.get(), type(self.get())\n        ):\n            return False\n\n        return hash(self) == hash(other)\n\n\n@dataclasses.dataclass\nclass CustomHashedObject(Generic[T]):\n    \"\"\"\n    A class that wraps an object and computes its hash with a custom hash function.\n    \"\"\"\n\n    obj: T\n\n    hash_function: Callable[[T], int] = hash\n\n    def get(self: CustomHashedObject[T]) -> T:\n        \"\"\"\n        Get the wrapped object.\n        \"\"\"\n        return self.obj\n\n    def __hash__(self) -> int:\n\n        return self.hash_function(self.obj)\n\n    def __eq__(self, other: CustomHashedObject[T]) -> bool:\n\n        if not isinstance(other, CustomHashedObject) and isinstance(\n            other.get(), type(self.get())\n        ):\n            return False\n\n        return hash(self) == hash(other)\n\n\n@jax_dataclasses.pytree_dataclass\nclass HashedNumpyArray:\n    \"\"\"\n    A class that wraps a numpy array and makes it hashable.\n\n    This is useful for creating particular JAX pytrees.\n    For example, to 
create a pytree with a plain NumPy or JAX NumPy array as static leaf.\n\n    Note:\n        Calculating with the wrapper class the hash of a very large array can be\n        very expensive. If the array is large and only the equality operator is needed,\n        set `large_array=True` to use a faster comparison method.\n    \"\"\"\n\n    array: jax.Array | npt.NDArray\n\n    precision: float | None = dataclasses.field(\n        default=1e-9, repr=False, compare=False, hash=False\n    )\n\n    large_array: jax_dataclasses.Static[bool] = dataclasses.field(\n        default=False, repr=False, compare=False, hash=False\n    )\n\n    def get(self) -> jax.Array | npt.NDArray:\n        \"\"\"\n        Get the wrapped array.\n        \"\"\"\n        return self.array\n\n    def __hash__(self) -> int:\n\n        return HashedNumpyArray.hash_of_array(\n            array=self.array, precision=self.precision\n        )\n\n    def __eq__(self, other: HashedNumpyArray) -> bool:\n\n        if not isinstance(other, HashedNumpyArray):\n            return False\n\n        if self.large_array:\n            return np.allclose(\n                self.array,\n                other.array,\n                **(dict(atol=self.precision) if self.precision is not None else {}),\n            )\n\n        return hash(self) == hash(other)\n\n    @staticmethod\n    def hash_of_array(\n        array: jax.Array | npt.NDArray, precision: float | None = 1e-9\n    ) -> int:\n        \"\"\"\n        Calculate the hash of a NumPy array.\n\n        Args:\n            array: The array to hash.\n            precision: Optionally limit the precision over which the hash is computed.\n\n        Returns:\n            The hash of the array.\n        \"\"\"\n\n        array = np.array(array).flatten()\n\n        array = np.where(array == np.nan, hash(np.nan), array)\n        array = np.where(array == np.inf, hash(np.inf), array)\n        array = np.where(array == -np.inf, hash(-np.inf), array)\n\n        if 
precision is not None:\n\n            integer1 = (array * precision).astype(int)\n            integer2 = (array - integer1 / precision).astype(int)\n\n            decimal_array = ((array - integer1 * 1e9 - integer2) / precision).astype(\n                int\n            )\n\n            array = np.hstack([integer1, integer2, decimal_array]).astype(int)\n\n        return hash(tuple(array.tolist()))\n"
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/assets/4_bar_opened.urdf",
    "content": "<robot name=\"4_bar_opened\">\n    <!-- Link AB -->\n    <link name=\"AB\">\n        <inertial>\n            <mass value=\"1.0\" />\n            <inertia ixx=\"0.02167\" iyy=\"0.00167\" izz=\"0.02167\" ixy=\"0.0\" ixz=\"0.0\" iyz=\"0.0\" />\n        </inertial>\n        <visual>\n            <geometry>\n                <box size=\"0.1 0.5 0.1\" />\n            </geometry>\n        </visual>\n    </link>\n\n    <!-- Link BC1 -->\n    <link name=\"BC1\">\n        <inertial>\n            <origin xyz=\"0 0.125 0\" rpy=\"0 0 0\" />\n            <mass value=\"0.5\" />\n            <inertia ixx=\"0.010835\" iyy=\"0.000835\" izz=\"0.010835\" ixy=\"0.0\" ixz=\"0.0\" iyz=\"0.0\" />\n        </inertial>\n        <visual>\n            <origin xyz=\"0 0.125 0\" rpy=\"0 0 0\" />\n            <geometry>\n                <box size=\"0.1 0.25 0.1\" />\n            </geometry>\n        </visual>\n    </link>\n\n    <!-- Frame at the end of BC1 -->\n    <link name=\"BC1_frame\" />\n\n    <joint name=\"BC1_frame_joint\" type=\"fixed\">\n        <parent link=\"BC1\" />\n        <child link=\"BC1_frame\" />\n        <origin xyz=\"0 0.25 0\" rpy=\"0 0 0\" />\n    </joint>\n\n    <!-- Link BC2 -->\n    <link name=\"BC2\">\n        <inertial>\n            <origin xyz=\"0 0.125 0\" rpy=\"0 0 0\" />\n            <mass value=\"0.5\" />\n            <inertia ixx=\"0.010835\" iyy=\"0.000835\" izz=\"0.010835\" ixy=\"0.0\" ixz=\"0.0\" iyz=\"0.0\" />\n        </inertial>\n        <visual>\n            <origin xyz=\"0 0.125 0\" rpy=\"0 0 0\" />\n            <geometry>\n                <box size=\"0.1 0.25 0.1\" /> <!-- Dummy frame link -->\n            </geometry>\n        </visual>\n    </link>\n\n    <!-- Frame at the end of BC2 -->\n    <link name=\"BC2_frame\" />\n\n    <joint name=\"BC2_frame_joint\" type=\"fixed\">\n        <parent link=\"BC2\" />\n        <child link=\"BC2_frame\" />\n        <origin xyz=\"0 0.25 0\" rpy=\"0 0 3.1416\" />\n    </joint>\n\n    <!-- Link CD 
-->\n    <link name=\"CD\">\n        <inertial>\n            <origin xyz=\"0 0.25 0\" rpy=\"0 0 0\" />\n            <mass value=\"1.0\" />\n            <inertia ixx=\"0.02167\" iyy=\"0.00167\" izz=\"0.02167\" ixy=\"0.0\" ixz=\"0.0\" iyz=\"0.0\" />\n        </inertial>\n        <visual>\n            <origin xyz=\"0 0.25 0\" rpy=\"0 0 0\" />\n            <geometry>\n                <box size=\"0.1 0.5 0.1\" />\n            </geometry>\n            <material name=\"Cyan\">\n                <color rgba=\"0 1.0 1.0 1.0\" />\n            </material>\n        </visual>\n        <collision>\n            <origin xyz=\"0 0.25 0\" rpy=\"0 0 0\" />\n            <geometry>\n                <box size=\"0.1 0.5 0.1\" />\n            </geometry>\n        </collision>\n    </link>\n\n    <!-- Link DA -->\n    <link name=\"DA\">\n        <inertial>\n            <origin xyz=\"0 0.25 0\" rpy=\"0 0 0\" />\n            <mass value=\"1.0\" />\n            <inertia ixx=\"0.02167\" iyy=\"0.00167\" izz=\"0.02167\" ixy=\"0.0\" ixz=\"0.0\" iyz=\"0.0\" />\n        </inertial>\n        <visual>\n            <origin xyz=\"0 0.25 0\" rpy=\"0 0 0\" />\n            <geometry>\n                <box size=\"0.1 0.5 0.1\" />\n            </geometry>\n        </visual>\n    </link>\n\n    <!-- Joint B -->\n    <joint name=\"B\" type=\"revolute\">\n        <parent link=\"AB\" />\n        <child link=\"BC1\" />\n        <origin xyz=\"0 -0.25 0\" rpy=\"0 0 1.57\" />\n        <axis xyz=\"0 0 1\" />\n        <limit lower=\"-1.57\" upper=\"1.57\" effort=\"30\" velocity=\"1.0\" />\n    </joint>\n\n    <!-- Joint C -->\n    <joint name=\"C\" type=\"revolute\">\n        <parent link=\"CD\" />\n        <child link=\"BC2\" />\n        <origin xyz=\"0 0.5 0\" rpy=\"0 0 1.57\" />\n        <axis xyz=\"0 0 1\" />\n        <limit lower=\"-1.57\" upper=\"1.57\" effort=\"30\" velocity=\"1.0\" />\n    </joint>\n\n    <!-- Joint D -->\n    <joint name=\"D\" type=\"revolute\">\n        <parent link=\"DA\" />\n        <child 
link=\"CD\" />\n        <origin xyz=\"0 0.5 0\" rpy=\"0 0 1.57\" />\n        <axis xyz=\"0 0 1\" />\n        <limit lower=\"-1.57\" upper=\"1.57\" effort=\"30\" velocity=\"1.0\" />\n    </joint>\n\n    <!-- Joint A -->\n    <joint name=\"A\" type=\"revolute\">\n        <parent link=\"AB\" />\n        <child link=\"DA\" />\n        <origin xyz=\"0 0.25 0\" rpy=\"0 0 1.57\" />\n        <axis xyz=\"0 0 1\" />\n        <limit lower=\"-1.57\" upper=\"1.57\" effort=\"30\" velocity=\"1.0\" />\n    </joint>\n</robot>\n"
  },
  {
    "path": "tests/assets/cube.stl",
    "content": "solid model\r\nfacet normal 0.0 0.0 -1.0\r\nouter loop\r\nvertex 20.0 0.0 0.0\r\nvertex 0.0 -20.0 0.0\r\nvertex 0.0 0.0 0.0\r\nendloop\r\nendfacet\r\nfacet normal 0.0 0.0 -1.0\r\nouter loop\r\nvertex 0.0 -20.0 0.0\r\nvertex 20.0 0.0 0.0\r\nvertex 20.0 -20.0 0.0\r\nendloop\r\nendfacet\r\nfacet normal -0.0 -1.0 -0.0\r\nouter loop\r\nvertex 20.0 -20.0 20.0\r\nvertex 0.0 -20.0 0.0\r\nvertex 20.0 -20.0 0.0\r\nendloop\r\nendfacet\r\nfacet normal -0.0 -1.0 -0.0\r\nouter loop\r\nvertex 0.0 -20.0 0.0\r\nvertex 20.0 -20.0 20.0\r\nvertex 0.0 -20.0 20.0\r\nendloop\r\nendfacet\r\nfacet normal 1.0 0.0 0.0\r\nouter loop\r\nvertex 20.0 0.0 0.0\r\nvertex 20.0 -20.0 20.0\r\nvertex 20.0 -20.0 0.0\r\nendloop\r\nendfacet\r\nfacet normal 1.0 0.0 0.0\r\nouter loop\r\nvertex 20.0 -20.0 20.0\r\nvertex 20.0 0.0 0.0\r\nvertex 20.0 0.0 20.0\r\nendloop\r\nendfacet\r\nfacet normal -0.0 -0.0 1.0\r\nouter loop\r\nvertex 20.0 -20.0 20.0\r\nvertex 0.0 0.0 20.0\r\nvertex 0.0 -20.0 20.0\r\nendloop\r\nendfacet\r\nfacet normal -0.0 -0.0 1.0\r\nouter loop\r\nvertex 0.0 0.0 20.0\r\nvertex 20.0 -20.0 20.0\r\nvertex 20.0 0.0 20.0\r\nendloop\r\nendfacet\r\nfacet normal -1.0 0.0 0.0\r\nouter loop\r\nvertex 0.0 0.0 20.0\r\nvertex 0.0 -20.0 0.0\r\nvertex 0.0 -20.0 20.0\r\nendloop\r\nendfacet\r\nfacet normal -1.0 0.0 0.0\r\nouter loop\r\nvertex 0.0 -20.0 0.0\r\nvertex 0.0 0.0 20.0\r\nvertex 0.0 0.0 0.0\r\nendloop\r\nendfacet\r\nfacet normal -0.0 1.0 0.0\r\nouter loop\r\nvertex 0.0 0.0 20.0\r\nvertex 20.0 0.0 0.0\r\nvertex 0.0 0.0 0.0\r\nendloop\r\nendfacet\r\nfacet normal -0.0 1.0 0.0\r\nouter loop\r\nvertex 20.0 0.0 0.0\r\nvertex 0.0 0.0 20.0\r\nvertex 20.0 0.0 20.0\r\nendloop\r\nendfacet\r\nendsolid model\r\n"
  },
  {
    "path": "tests/assets/double_pendulum.sdf",
    "content": "<?xml version=\"1.0\"?>\n\n<sdf version=\"1.7\">\n\n    <model name=\"double_pendulum\">\n        <!-- <pose>0 0 0.2 0 0 0</pose> -->\n        <joint name=\"fixed_base\" type=\"fixed\">\n            <parent>world</parent>\n            <child>base_link</child>\n            <axis>\n                <xyz>1 0 0</xyz>\n                <limit>\n                    <lower>-5</lower>\n                    <upper>5</upper>\n                    <effort>100</effort>\n                    <velocity>100</velocity>\n                </limit>\n                <dynamics>\n                    <damping>0.0</damping>\n                    <spring_reference>0</spring_reference>\n                    <spring_stiffness>0.0</spring_stiffness>\n                </dynamics>\n            </axis>\n        </joint>\n        <link name='base_link'>\n            <inertial>\n                <pose>0 0 0 0 0 0</pose>\n                <mass>100</mass>\n                <inertia>\n                    <ixx>1</ixx>\n                    <ixy>0</ixy>\n                    <ixz>0</ixz>\n                    <iyy>1</iyy>\n                    <iyz>0</iyz>\n                    <izz>1</izz>\n                </inertia>\n            </inertial>\n            <collision name='base_link_collision'>\n                <pose>0 0 1 0 0 0</pose>\n                <geometry>\n                    <box>\n                        <size>0.20 0.20 2.15</size>\n                    </box>\n                </geometry>\n            </collision>\n            <visual name='base_link_visual'>\n                <pose>0 0 1 0 0 0</pose>\n                <geometry>\n                    <box>\n                        <size>0.20 0.20 2.15</size>\n                    </box>\n                </geometry>\n            </visual>\n        </link>\n        <joint name='right_joint' type='revolute'>\n            <pose relative_to='base_link'>0.20 0 2 -3.1415 0 0</pose>\n            <parent>base_link</parent>\n            
<child>right_link</child>\n            <axis>\n                <xyz>1 0 0</xyz>\n                <limit>\n                    <lower>-100</lower>\n                    <upper>100</upper>\n                    <effort>100</effort>\n                    <velocity>100</velocity>\n                </limit>\n                <dynamics>\n                    <damping>1.0</damping>\n                    <spring_reference>0</spring_reference>\n                    <spring_stiffness>0.0</spring_stiffness>\n                </dynamics>\n            </axis>\n        </joint>\n        <link name='right_link'>\n            <pose relative_to='right_joint'>0 0 0 0 0 0</pose>\n            <self_collide>0</self_collide>\n            <inertial>\n                <pose>0 0 0.5 0 0 0</pose>\n                <mass>1</mass>\n                <inertia>\n                    <ixx>1.0</ixx>\n                    <ixy>0</ixy>\n                    <ixz>0</ixz>\n                    <iyy>1.0</iyy>\n                    <iyz>0</iyz>\n                    <izz>1.0</izz>\n                </inertia>\n            </inertial>\n            <!-- <collision name='right_link_collision'>\n                <pose>0 0 0.5 0 0 0</pose>\n                <geometry>\n                    <box>\n                        <size>0.20 0.20 1.0</size>\n                    </box>\n                </geometry>\n            </collision> -->\n            <visual name='right_link_visual'>\n                <pose>0 0 0.5 0 0 0</pose>\n                <geometry>\n                    <box>\n                        <size>0.20 0.20 1.0</size>\n                    </box>\n                </geometry>\n            </visual>\n        </link>\n\n        <joint name='left_joint' type='revolute'>\n            <pose relative_to='base_link'>-0.20 0 2 -3.1415 0 0</pose>\n            <parent>base_link</parent>\n            <child>left_link</child>\n            <axis>\n                <xyz>1 0 0</xyz>\n                <limit>\n                    
<lower>-100</lower>\n                    <upper>100</upper>\n                    <effort>100</effort>\n                    <velocity>100</velocity>\n                </limit>\n                <dynamics>\n                    <damping>1.0</damping>\n                    <spring_reference>0</spring_reference>\n                    <spring_stiffness>0.0</spring_stiffness>\n                </dynamics>\n            </axis>\n        </joint>\n        <link name='left_link'>\n            <pose relative_to='left_joint'>0 0 0 0 0 0</pose>\n            <self_collide>0</self_collide>\n            <inertial>\n                <pose>0.0 0 0.5 0 0 0</pose>\n                <mass>1</mass>\n                <inertia>\n                    <ixx>1.0</ixx>\n                    <ixy>0</ixy>\n                    <ixz>0</ixz>\n                    <iyy>1.0</iyy>\n                    <iyz>0</iyz>\n                    <izz>1.0</izz>\n                </inertia>\n            </inertial>\n            <!-- <collision name='left_link_collision'>\n                <pose>0.0 0 0.5 0 0 0</pose>\n                <geometry>\n                    <box>\n                        <size>0.20 0.20 1.0</size>\n                    </box>\n                </geometry>\n            </collision> -->\n            <visual name='left_link_visual'>\n                <pose>0.0 0 0.5 0 0 0</pose>\n                <geometry>\n                    <box>\n                        <size>0.20 0.20 1.0</size>\n                    </box>\n                </geometry>\n            </visual>\n        </link>\n        <frame name=\"base_link_middle_right_frame\" attached_to=\"base_link\">\n            <pose relative_to=\"base_link\">0.20 0 1 0 0 0</pose>\n        </frame>\n        <frame name=\"base_link_middle_left_frame\" attached_to=\"base_link\">\n            <pose relative_to=\"base_link\">-0.20 0 1 0 0 0</pose>\n        </frame>\n        <frame name=\"right_link_extremity_frame\" attached_to=\"right_link\">\n            <pose 
relative_to=\"right_link\"> -0.2 0 1 3.14 0 0</pose>\n        </frame>\n        <frame name=\"left_link_extremity_frame\" attached_to=\"left_link\">\n            <pose relative_to=\"left_link\"> 0.2 0 1 3.14 0 0</pose>\n        </frame>\n    </model>\n\n</sdf>\n"
  },
  {
    "path": "tests/assets/mixed_shapes_robot.urdf",
    "content": "<?xml version=\"1.0\"?>\n<robot name=\"mixed_shapes_robot\">\n\n  <!-- Link 1: Box primitive -->\n  <link name=\"box_link\">\n    <inertial>\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\n      <mass value=\"1.0\"/>\n      <inertia ixx=\"0.01\" ixy=\"0.0\" ixz=\"0.0\"\n               iyy=\"0.01\" iyz=\"0.0\"\n               izz=\"0.01\"/>\n    </inertial>\n\n    <visual>\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\n      <geometry>\n        <box size=\"0.1 0.1 0.1\"/>\n      </geometry>\n      <material name=\"red\">\n        <color rgba=\"0.8 0.2 0.2 1.0\"/>\n      </material>\n    </visual>\n  </link>\n\n  <!-- Link 2: Cylinder primitive -->\n  <link name=\"cylinder_link\">\n    <inertial>\n      <origin xyz=\"0 0 0.05\" rpy=\"0 0 0\"/>\n      <mass value=\"0.5\"/>\n      <inertia ixx=\"0.005\" ixy=\"0.0\" ixz=\"0.0\"\n               iyy=\"0.005\" iyz=\"0.0\"\n               izz=\"0.002\"/>\n    </inertial>\n\n    <visual>\n      <origin xyz=\"0 0 0.05\" rpy=\"0 0 0\"/>\n      <geometry>\n        <cylinder radius=\"0.03\" length=\"0.1\"/>\n      </geometry>\n      <material name=\"green\">\n        <color rgba=\"0.2 0.8 0.2 1.0\"/>\n      </material>\n    </visual>\n  </link>\n\n  <!-- Link 3: Mesh -->\n  <link name=\"mesh_link\">\n    <inertial>\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\n      <mass value=\"0.008\"/>\n      <inertia ixx=\"5.333e-7\" ixy=\"0.0\" ixz=\"0.0\"\n               iyy=\"5.333e-7\" iyz=\"0.0\"\n               izz=\"5.333e-7\"/>\n    </inertial>\n\n    <visual>\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\n      <geometry>\n        <mesh filename=\"tests/assets/cube.stl\" scale=\"0.001 0.001 0.001\"/>\n      </geometry>\n      <material name=\"blue\">\n        <color rgba=\"0.2 0.2 0.8 1.0\"/>\n      </material>\n    </visual>\n  </link>\n\n  <!-- Link 4: Sphere primitive -->\n  <link name=\"sphere_link\">\n    <inertial>\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\n      <mass value=\"0.3\"/>\n      <inertia 
ixx=\"0.0012\" ixy=\"0.0\" ixz=\"0.0\"\n               iyy=\"0.0012\" iyz=\"0.0\"\n               izz=\"0.0012\"/>\n    </inertial>\n\n    <visual>\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\n      <geometry>\n        <sphere radius=\"0.05\"/>\n      </geometry>\n      <material name=\"yellow\">\n        <color rgba=\"0.8 0.8 0.2 1.0\"/>\n      </material>\n    </visual>\n  </link>\n\n  <joint name=\"box_to_cylinder\" type=\"revolute\">\n    <parent link=\"box_link\"/>\n    <child link=\"cylinder_link\"/>\n    <origin xyz=\"0 0 0.1\" rpy=\"0 0 0\"/>\n    <axis xyz=\"0 0 1\"/>\n    <limit lower=\"-1.57\" upper=\"1.57\" effort=\"10\" velocity=\"1.0\"/>\n  </joint>\n\n  <joint name=\"cylinder_to_mesh\" type=\"revolute\">\n    <parent link=\"cylinder_link\"/>\n    <child link=\"mesh_link\"/>\n    <origin xyz=\"0 0 0.15\" rpy=\"0 0 0\"/>\n    <axis xyz=\"1 0 0\"/>\n    <limit lower=\"-1.57\" upper=\"1.57\" effort=\"10\" velocity=\"1.0\"/>\n  </joint>\n\n  <joint name=\"mesh_to_sphere\" type=\"revolute\">\n    <parent link=\"mesh_link\"/>\n    <child link=\"sphere_link\"/>\n    <origin xyz=\"0 0 0.05\" rpy=\"0 0 0\"/>\n    <axis xyz=\"0 1 0\"/>\n    <limit lower=\"-1.57\" upper=\"1.57\" effort=\"10\" velocity=\"1.0\"/>\n  </joint>\n\n</robot>\n"
  },
  {
    "path": "tests/assets/test_cube.urdf",
    "content": "<?xml version=\"1.0\"?>\n<robot name=\"test_cube\">\n\n  <!-- Single cube link with mesh visual -->\n  <link name=\"cube_link\">\n    <inertial>\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\n      <!-- For a 20mm (0.02m) cube with density 1000 kg/m3:\n           Volume = 0.02 **3 = 8e-6 m3\n           Mass = 1000 * 8e-6 = 0.008 kg = 8 grams -->\n      <mass value=\"0.008\"/>\n      <!-- For a cube: Ixx = Iyy = Izz = (1/6)*m*L2\n           I = (1/6) * 0.008 * 0.02 ** 2 = 5.333e-7 kgm2 -->\n      <inertia ixx=\"5.333e-7\" ixy=\"0.0\" ixz=\"0.0\"\n               iyy=\"5.333e-7\" iyz=\"0.0\"\n               izz=\"5.333e-7\"/>\n    </inertial>\n\n    <visual>\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\n      <geometry>\n        <mesh filename=\"tests/assets/cube.stl\" scale=\"0.001 0.001 0.001\"/>\n      </geometry>\n      <material name=\"blue\">\n        <color rgba=\"0.2 0.2 0.8 1.0\"/>\n      </material>\n    </visual>\n\n    <collision>\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\n      <geometry>\n        <mesh filename=\"tests/assets/cube.stl\" scale=\"0.001 0.001 0.001\"/>\n      </geometry>\n    </collision>\n  </link>\n\n</robot>\n"
  },
  {
    "path": "tests/conftest.py",
    "content": "import os\n\nos.environ[\"JAXSIM_ENABLE_EXCEPTIONS\"] = \"1\"\n\nimport pathlib\nimport subprocess\n\nimport jax\nimport numpy as np\nimport pytest\nimport rod\nimport rod.urdf.exporter\n\nimport jaxsim\nimport jaxsim.api as js\nfrom jaxsim.api.model import IntegratorType\n\n\ndef pytest_addoption(parser):\n    parser.addoption(\n        \"--gpu-only\",\n        action=\"store_true\",\n        default=False,\n        help=\"Run tests only if GPU is available and utilized\",\n    )\n\n    parser.addoption(\n        \"--batch-size\",\n        action=\"store\",\n        default=\"None\",\n        help=\"Batch size for vectorized benchmarks (only applies to benchmark tests)\",\n    )\n\n\ndef pytest_generate_tests(metafunc):\n    if (\n        \"batch_size\" in metafunc.fixturenames\n        and (batch_size := metafunc.config.getoption(\"--batch-size\")) != \"None\"\n    ):\n        metafunc.parametrize(\"batch_size\", [1, int(batch_size)])\n\n\ndef check_gpu_usage():\n    # Set environment variable to prioritize GPU.\n    os.environ[\"JAX_PLATFORM_NAME\"] = \"gpu\"\n\n    # Run a simple JAX operation\n    x = jax.device_put(jax.numpy.ones((512, 512)))\n    y = jax.device_put(jax.numpy.ones((512, 512)))\n    _ = jax.numpy.dot(x, y).block_until_ready()\n\n    # Check GPU memory usage with nvidia-smi.\n    result = subprocess.run(\n        [\"nvidia-smi\", \"--query-gpu=memory.used\", \"--format=csv,noheader\"],\n        capture_output=True,\n        text=True,\n    )\n    if result.returncode != 0:\n        pytest.exit(\n            \"Failed to query GPU usage. Ensure nvidia-smi is installed and accessible.\"\n        )\n\n    gpu_memory_usage = [\n        int(line.strip().split()[0]) for line in result.stdout.splitlines()\n    ]\n    if all(usage == 0 for usage in gpu_memory_usage):\n        pytest.exit(\n            \"GPU is available but not utilized during computations. 
Check your JAX installation.\"\n        )\n\n\ndef pytest_configure(config) -> None:\n    \"\"\"Pytest configuration hook.\"\"\"\n\n    # This is a global variable that is updated by the `prng_key` fixture.\n    pytest.prng_key = jax.random.PRNGKey(\n        seed=int(os.environ.get(\"JAXSIM_TEST_SEED\", 0))\n    )\n\n    # Check if GPU is available and utilized.\n    if config.getoption(\"--gpu-only\"):\n        devices = jax.devices()\n        if not any(device.platform == \"gpu\" for device in devices):\n            pytest.exit(\"No GPU devices found. Check your JAX installation.\")\n\n        # Ensure GPU is being used during computation\n        check_gpu_usage()\n\n\ndef load_model_from_file(file_path: pathlib.Path, is_urdf=False) -> rod.Sdf:\n    \"\"\"\n    Load an SDF or URDF model from a file.\n\n    Args:\n        file_path: The path to the model file.\n        is_urdf: Whether the file is in URDF or SDF format.\n\n    Returns:\n        The corresponding rod model.\n    \"\"\"\n\n    return rod.Sdf.load(file_path, is_urdf=is_urdf)\n\n\n# ================\n# Generic fixtures\n# ================\n\n\n@pytest.fixture(scope=\"function\")\ndef prng_key() -> jax.Array:\n    \"\"\"\n    Fixture to generate a new PRNG key for each test function.\n\n    Returns:\n        The new PRNG key passed to the test.\n\n    Note:\n        This fixture operates on a global variable initialized in the\n        `pytest_configure` hook.\n    \"\"\"\n\n    pytest.prng_key, subkey = jax.random.split(pytest.prng_key, num=2)\n    return subkey\n\n\n@pytest.fixture(\n    scope=\"function\",\n    params=[\n        pytest.param(jaxsim.VelRepr.Inertial, id=\"inertial\"),\n        pytest.param(jaxsim.VelRepr.Body, id=\"body\"),\n        pytest.param(jaxsim.VelRepr.Mixed, id=\"mixed\"),\n    ],\n)\ndef velocity_representation(request) -> jaxsim.VelRepr:\n    \"\"\"\n    Parametrized fixture providing all supported velocity representations.\n\n    Returns:\n        A velocity 
representation.\n    \"\"\"\n\n    return request.param\n\n\n@pytest.fixture(\n    scope=\"function\",\n    params=[\n        pytest.param(IntegratorType.SemiImplicitEuler, id=\"semi_implicit_euler\"),\n        pytest.param(IntegratorType.RungeKutta4, id=\"runge_kutta_4\"),\n        pytest.param(IntegratorType.RungeKutta4Fast, id=\"runge_kutta_4_fast\"),\n    ],\n)\ndef integrator(request) -> str:\n    \"\"\"\n    Fixture providing the integrator to use in the simulation.\n\n    Returns:\n        The integrator to use in the simulation.\n    \"\"\"\n\n    return request.param\n\n\n@pytest.fixture(scope=\"session\")\ndef batch_size(request) -> int:\n    \"\"\"\n    Fixture providing the batch size for vectorized benchmarks.\n\n    Returns:\n        The batch size for vectorized benchmarks.\n    \"\"\"\n\n    return 1\n\n\n# ================================\n# Fixtures providing JaxSim models\n# ================================\n\n# All the fixtures in this section must have \"session\" scope.\n# In this way, the models are generated only once and shared among all the tests.\n\n\n# This is not a fixture.\ndef build_jaxsim_model(\n    model_description: str | pathlib.Path | rod.Model,\n) -> js.model.JaxSimModel:\n    \"\"\"\n    Build a JaxSim model from a model description.\n\n    Args:\n        model_description: A model description provided by any fixture provider.\n\n    Returns:\n        A JaxSim model built from the provided description.\n    \"\"\"\n\n    # Build the JaxSim model.\n    model = js.model.JaxSimModel.build_from_model_description(\n        model_description=model_description,\n    )\n\n    return model\n\n\n@pytest.fixture(scope=\"session\")\ndef jaxsim_model_box() -> js.model.JaxSimModel:\n    \"\"\"\n    Fixture providing the JaxSim model of a box.\n\n    Returns:\n        The JaxSim model of a box.\n    \"\"\"\n\n    import rod.builder.primitives\n    import rod.urdf.exporter\n\n    # Create on-the-fly a ROD model of a box.\n    rod_model = (\n  
      rod.builder.primitives.BoxBuilder(x=0.3, y=0.2, z=0.1, mass=1.0, name=\"box\")\n        .build_model()\n        .add_link(name=\"box_link\")\n        .add_inertial()\n        .add_visual()\n        .add_collision()\n        .build()\n    )\n\n    rod_model.add_frame(\n        rod.Frame(\n            name=\"box_frame\",\n            attached_to=\"box_link\",\n            pose=rod.Pose(relative_to=\"box_link\", pose=[1, 1, 1, 0.5, 0.4, 0.3]),\n        )\n    )\n\n    # Export the URDF string.\n    urdf_string = rod.urdf.exporter.UrdfExporter(\n        pretty=True, gazebo_preserve_fixed_joints=True\n    ).to_urdf_string(sdf=rod_model)\n\n    return build_jaxsim_model(model_description=urdf_string)\n\n\n@pytest.fixture(scope=\"session\")\ndef jaxsim_model_sphere() -> js.model.JaxSimModel:\n    \"\"\"\n    Fixture providing the JaxSim model of a sphere.\n\n    Returns:\n        The JaxSim model of a sphere.\n    \"\"\"\n\n    import rod.builder.primitives\n    import rod.urdf.exporter\n\n    # Create on-the-fly a ROD model of a sphere.\n    rod_model = (\n        rod.builder.primitives.SphereBuilder(radius=0.1, mass=1.0, name=\"sphere\")\n        .build_model()\n        .add_link()\n        .add_inertial()\n        .add_visual()\n        .add_collision()\n        .build()\n    )\n\n    # Export the URDF string.\n    urdf_string = rod.urdf.exporter.UrdfExporter(pretty=True).to_urdf_string(\n        sdf=rod_model\n    )\n\n    return build_jaxsim_model(model_description=urdf_string)\n\n\n@pytest.fixture(scope=\"session\")\ndef ergocub_model_description_path() -> pathlib.Path:\n    \"\"\"\n    Fixture providing the path to the URDF model description of the ErgoCub robot.\n\n    Returns:\n        The path to the URDF model description of the ErgoCub robot.\n\n    \"\"\"\n\n    try:\n        os.environ[\"ROBOT_DESCRIPTION_COMMIT\"] = \"v0.7.7\"\n\n        import robot_descriptions.ergocub_description\n\n    finally:\n        _ = 
os.environ.pop(\"ROBOT_DESCRIPTION_COMMIT\", None)\n\n    model_urdf_path = pathlib.Path(\n        robot_descriptions.ergocub_description.URDF_PATH.replace(\n            \"ergoCubSN002\", \"ergoCubSN001\"\n        )\n    )\n\n    return model_urdf_path\n\n\n@pytest.fixture(scope=\"session\")\ndef jaxsim_model_ergocub(\n    ergocub_model_description_path: pathlib.Path,\n) -> js.model.JaxSimModel:\n    \"\"\"\n    Fixture providing the JaxSim model of the ErgoCub robot.\n\n    Returns:\n        The JaxSim model of the ErgoCub robot.\n\n    \"\"\"\n\n    return build_jaxsim_model(model_description=ergocub_model_description_path)\n\n\n@pytest.fixture(scope=\"session\")\ndef jaxsim_model_ergocub_reduced(jaxsim_model_ergocub) -> js.model.JaxSimModel:\n    \"\"\"\n    Fixture providing the JaxSim model of the ErgoCub robot with only locomotion joints.\n\n    Returns:\n        The JaxSim model of the ErgoCub robot with only locomotion joints.\n\n    \"\"\"\n\n    model_full = jaxsim_model_ergocub\n\n    # Get the names of the joints to keep.\n    reduced_joints = tuple(\n        j\n        for j in model_full.joint_names()\n        if \"camera\" not in j\n        # Remove head and hands.\n        and \"neck\" not in j\n        and \"wrist\" not in j\n        and \"thumb\" not in j\n        and \"index\" not in j\n        and \"middle\" not in j\n        and \"ring\" not in j\n        and \"pinkie\" not in j\n        # Remove upper body.\n        and \"torso\" not in j and \"elbow\" not in j and \"shoulder\" not in j\n    )\n\n    model = js.model.reduce(model=model_full, considered_joints=reduced_joints)\n\n    return model\n\n\n@pytest.fixture(scope=\"session\")\ndef jaxsim_model_ur10() -> js.model.JaxSimModel:\n    \"\"\"\n    Fixture providing the JaxSim model of the UR10 robot.\n\n    Returns:\n        The JaxSim model of the UR10 robot.\n\n    \"\"\"\n\n    import robot_descriptions.ur10_description\n\n    model_urdf_path = 
pathlib.Path(robot_descriptions.ur10_description.URDF_PATH)\n\n    return build_jaxsim_model(model_description=model_urdf_path)\n\n\n@pytest.fixture(scope=\"session\")\ndef jaxsim_model_single_pendulum() -> js.model.JaxSimModel:\n    \"\"\"\n    Fixture providing the JaxSim model of a single pendulum.\n\n    Returns:\n        The JaxSim model of a single pendulum.\n    \"\"\"\n\n    import rod.builder.primitives\n\n    base_height = 2.15\n    upper_height = 1.0\n\n    # ===================\n    # Create the builders\n    # ===================\n\n    base_builder = rod.builder.primitives.BoxBuilder(\n        name=\"base\",\n        mass=1.0,\n        x=0.15,\n        y=0.15,\n        z=base_height,\n    )\n\n    upper_builder = rod.builder.primitives.BoxBuilder(\n        name=\"upper\",\n        mass=0.5,\n        x=0.15,\n        y=0.15,\n        z=upper_height,\n    )\n\n    # =================\n    # Create the joints\n    # =================\n\n    fixed = rod.Joint(\n        name=\"fixed_joint\",\n        type=\"fixed\",\n        parent=\"world\",\n        child=base_builder.name,\n    )\n\n    pivot = rod.Joint(\n        name=\"upper_joint\",\n        type=\"continuous\",\n        parent=base_builder.name,\n        child=upper_builder.name,\n        axis=rod.Axis(\n            xyz=rod.Xyz([1, 0, 0]),\n        ),\n    )\n\n    # ================\n    # Create the links\n    # ================\n\n    base = (\n        base_builder.build_link(\n            name=base_builder.name,\n            pose=rod.builder.primitives.PrimitiveBuilder.build_pose(\n                pos=np.array([0, 0, base_height / 2])\n            ),\n        )\n        .add_inertial()\n        .add_visual()\n        .add_collision()\n        .build()\n    )\n\n    upper_pose = rod.builder.primitives.PrimitiveBuilder.build_pose(\n        pos=np.array([0, 0, upper_height / 2])\n    )\n\n    upper = (\n        upper_builder.build_link(\n            name=upper_builder.name,\n            
pose=rod.builder.primitives.PrimitiveBuilder.build_pose(\n                relative_to=base.name, pos=np.array([0, 0, upper_height])\n            ),\n        )\n        .add_inertial(pose=upper_pose)\n        .add_visual(pose=upper_pose)\n        .add_collision(pose=upper_pose)\n        .build()\n    )\n\n    rod_model = rod.Sdf(\n        version=\"1.10\",\n        model=rod.Model(\n            name=\"single_pendulum\",\n            link=[base, upper],\n            joint=[fixed, pivot],\n        ),\n    )\n\n    rod_model.model.resolve_frames()\n\n    urdf_string = rod.urdf.exporter.UrdfExporter(pretty=True).to_urdf_string(\n        sdf=rod_model.models()[0]\n    )\n\n    model = build_jaxsim_model(model_description=urdf_string)\n\n    return model\n\n\n@pytest.fixture(scope=\"session\")\ndef jaxsim_model_garpez() -> js.model.JaxSimModel:\n    \"\"\"Fixture to create the original (unscaled) Garpez model.\"\"\"\n\n    rod_model = create_scalable_garpez_model()\n\n    urdf_string = rod.urdf.exporter.UrdfExporter(pretty=True).to_urdf_string(\n        sdf=rod_model\n    )\n\n    return build_jaxsim_model(model_description=urdf_string)\n\n\n@pytest.fixture(scope=\"session\")\ndef jaxsim_model_garpez_scaled(request) -> js.model.JaxSimModel:\n    \"\"\"Fixture to create the scaled version of the Garpez model.\"\"\"\n\n    # Get the link scales from the request.\n    link1_scale = request.param.get(\"link1_scale\", 1.0)\n    link2_scale = request.param.get(\"link2_scale\", 1.0)\n    link3_scale = request.param.get(\"link3_scale\", 1.0)\n    link4_scale = request.param.get(\"link4_scale\", 1.0)\n\n    rod_model = create_scalable_garpez_model(\n        link1_scale=link1_scale,\n        link2_scale=link2_scale,\n        link3_scale=link3_scale,\n        link4_scale=link4_scale,\n    )\n\n    urdf_string = rod.urdf.exporter.UrdfExporter(pretty=True).to_urdf_string(\n        sdf=rod_model\n    )\n\n    return build_jaxsim_model(model_description=urdf_string)\n\n\ndef 
create_scalable_garpez_model(\n    link1_scale: float = 1.0,\n    link2_scale: float = 1.0,\n    link3_scale: float = 1.0,\n    link4_scale: float = 1.0,\n) -> rod.Model:\n    \"\"\"\n    Build a scalable rod model to test parameterization and scaling.\n\n    Args:\n        link1_scale: Scale factor for link 1.\n        link2_scale: Scale factor for link 2.\n        link3_scale: Scale factor for link 3.\n        link4_scale: Scale factor for link 4.\n\n    Returns:\n        A rod model with the specified link scales.\n\n    Note:\n        The model is built assuming a constant link density, hence scaling the link will also have an impact on the link mass.\n    \"\"\"\n\n    import numpy as np\n    from rod.builder import primitives\n\n    # ========================\n    # Create the link builders\n    # ========================\n\n    density = 1000.0  # Fixed density in kg/m^3\n\n    l1_x, l1_y, l1_z = 0.3 * link1_scale, 0.2, 0.2\n    l1_volume = l1_x * l1_y * l1_z\n    l1_mass = density * l1_volume\n    link1_builder = primitives.BoxBuilder(\n        name=\"link1\", mass=l1_mass, x=l1_x, y=l1_y, z=l1_z\n    )\n\n    l2_radius = 0.1 * link2_scale\n    l2_volume = 4 / 3 * np.pi * l2_radius**3\n    l2_mass = density * l2_volume\n    link2_builder = primitives.SphereBuilder(\n        name=\"link2\", mass=l2_mass, radius=l2_radius\n    )\n\n    l3_radius = 0.05\n    l3_length = 0.5 * link3_scale\n    l3_volume = np.pi * l3_radius**2 * l3_length\n    l3_mass = density * l3_volume\n    link3_builder = primitives.CylinderBuilder(\n        name=\"link3\", mass=l3_mass, radius=l3_radius, length=l3_length\n    )\n\n    l4_x, l4_y, l4_z = 0.3 * link4_scale, 0.2, 0.1\n    l4_volume = l4_x * l4_y * l4_z\n    l4_mass = density * l4_volume\n    link4_builder = primitives.BoxBuilder(\n        name=\"link4\", mass=l4_mass, x=l4_x, y=l4_y, z=l4_z\n    )\n\n    # =================\n    # Create the joints\n    # =================\n\n    link1_to_link2 = rod.Joint(\n        
name=\"link1_to_link2\",\n        type=\"revolute\",\n        parent=link1_builder.name,\n        child=link2_builder.name,\n        pose=primitives.PrimitiveBuilder.build_pose(\n            relative_to=link1_builder.name,\n            pos=np.array([link1_builder.x, link1_builder.y / 2, link1_builder.z / 2]),\n        ),\n        axis=rod.Axis(xyz=rod.Xyz(xyz=[0, 1, 0]), limit=rod.Limit()),\n    )\n\n    link2_to_link3 = rod.Joint(\n        name=\"link2_to_link3\",\n        type=\"revolute\",\n        parent=link2_builder.name,\n        child=link3_builder.name,\n        pose=primitives.PrimitiveBuilder.build_pose(\n            relative_to=link2_builder.name,\n            pos=np.array([link2_builder.radius, 0, -link2_builder.radius]),\n        ),\n        axis=rod.Axis(xyz=rod.Xyz(xyz=[0, 0, 1]), limit=rod.Limit()),\n    )\n\n    link3_to_link4 = rod.Joint(\n        name=\"link3_to_link4\",\n        type=\"revolute\",\n        parent=link3_builder.name,\n        child=link4_builder.name,\n        pose=primitives.PrimitiveBuilder.build_pose(\n            relative_to=link3_builder.name,\n            pos=np.array([-link3_builder.radius, 0, -link3_builder.length]),\n        ),\n        axis=rod.Axis(xyz=rod.Xyz(xyz=[1, 0, 0]), limit=rod.Limit()),\n    )\n\n    # ================\n    # Create the links\n    # ================\n\n    link1_elements_pose = primitives.PrimitiveBuilder.build_pose(\n        pos=np.array([link1_builder.x, link1_builder.y, link1_builder.z]) / 2\n    )\n\n    link1 = (\n        link1_builder.build_link(\n            name=link1_builder.name,\n            pose=primitives.PrimitiveBuilder.build_pose(relative_to=\"__model__\"),\n        )\n        .add_inertial(pose=link1_elements_pose)\n        .add_visual(pose=link1_elements_pose)\n        .add_collision(pose=link1_elements_pose)\n        .build()\n    )\n\n    link2_elements_pose = primitives.PrimitiveBuilder.build_pose(\n        pos=np.array([link2_builder.radius, 0, 0])\n    )\n\n    link2 = 
(\n        link2_builder.build_link(\n            name=link2_builder.name,\n            pose=primitives.PrimitiveBuilder.build_pose(\n                relative_to=link1_to_link2.name\n            ),\n        )\n        .add_inertial(pose=link2_elements_pose)\n        .add_visual(pose=link2_elements_pose)\n        .add_collision(pose=link2_elements_pose)\n        .build()\n    )\n\n    link3_elements_pose = primitives.PrimitiveBuilder.build_pose(\n        pos=np.array([0, 0, -link3_builder.length / 2])\n    )\n\n    link3 = (\n        link3_builder.build_link(\n            name=link3_builder.name,\n            pose=primitives.PrimitiveBuilder.build_pose(\n                relative_to=link2_to_link3.name\n            ),\n        )\n        .add_inertial(pose=link3_elements_pose)\n        .add_visual(pose=link3_elements_pose)\n        .add_collision(pose=link3_elements_pose)\n        .build()\n    )\n\n    link4_elements_pose = primitives.PrimitiveBuilder.build_pose(\n        # pos=np.array([0, 0, -link4_builder.z / 2])\n        pos=np.array([link4_builder.x / 2, 0, -link4_builder.z / 2])\n    )\n\n    link4 = (\n        link4_builder.build_link(\n            name=link4_builder.name,\n            pose=primitives.PrimitiveBuilder.build_pose(\n                relative_to=link3_to_link4.name\n            ),\n        )\n        .add_inertial(pose=link4_elements_pose)\n        .add_visual(pose=link4_elements_pose)\n        .add_collision(pose=link4_elements_pose)\n        .build()\n    )\n\n    # ===========\n    # Build model\n    # ===========\n\n    # Create model\n    rod_model = rod.Model(\n        name=\"model_demo\",\n        canonical_link=link1.name,\n        link=[link1, link2, link3, link4],\n        joint=[link1_to_link2, link2_to_link3, link3_to_link4],\n    )\n\n    rod_model.switch_frame_convention(\n        frame_convention=rod.FrameConvention.Urdf,\n        explicit_frames=True,\n        attach_frames_to_links=True,\n    )\n\n    assert 
rod.Sdf(model=rod_model, version=\"1.10\").serialize(validate=True)\n\n    return rod_model\n\n\ndef create_model_with_missing_collision() -> rod.Model:\n    \"\"\"\n    Build a rod model with a link that has a visual but no collision element.\n\n    This model is used to test the export logic when collision elements are missing.\n\n    Returns:\n        A rod model with one link missing a collision element.\n    \"\"\"\n\n    import numpy as np\n    from rod.builder import primitives\n\n    density = 1000.0  # Fixed density in kg/m^3\n\n    # Create link1 with both visual and collision\n    l1_x, l1_y, l1_z = 0.3, 0.2, 0.2\n    l1_volume = l1_x * l1_y * l1_z\n    l1_mass = density * l1_volume\n    link1_builder = primitives.BoxBuilder(\n        name=\"link1\", mass=l1_mass, x=l1_x, y=l1_y, z=l1_z\n    )\n\n    # Create link2 with visual but WITHOUT collision\n    l2_radius = 0.1\n    l2_volume = 4 / 3 * np.pi * l2_radius**3\n    l2_mass = density * l2_volume\n    link2_builder = primitives.SphereBuilder(\n        name=\"link2\", mass=l2_mass, radius=l2_radius\n    )\n\n    # Create joint\n    link1_to_link2 = rod.Joint(\n        name=\"link1_to_link2\",\n        type=\"revolute\",\n        parent=link1_builder.name,\n        child=link2_builder.name,\n        pose=primitives.PrimitiveBuilder.build_pose(\n            relative_to=link1_builder.name,\n            pos=np.array([link1_builder.x, link1_builder.y / 2, link1_builder.z / 2]),\n        ),\n        axis=rod.Axis(xyz=rod.Xyz(xyz=[0, 1, 0]), limit=rod.Limit()),\n    )\n\n    # Build link1 with visual and collision\n    link1_elements_pose = primitives.PrimitiveBuilder.build_pose(\n        pos=np.array([link1_builder.x, link1_builder.y, link1_builder.z]) / 2\n    )\n\n    link1 = (\n        link1_builder.build_link(\n            name=link1_builder.name,\n            pose=primitives.PrimitiveBuilder.build_pose(relative_to=\"__model__\"),\n        )\n        .add_inertial(pose=link1_elements_pose)\n        
.add_visual(pose=link1_elements_pose)\n        .add_collision(pose=link1_elements_pose)\n        .build()\n    )\n\n    # Build link2 with visual but NO collision\n    link2_elements_pose = primitives.PrimitiveBuilder.build_pose(\n        pos=np.array([link2_builder.radius, 0, 0])\n    )\n\n    link2 = (\n        link2_builder.build_link(\n            name=link2_builder.name,\n            pose=primitives.PrimitiveBuilder.build_pose(\n                relative_to=link1_to_link2.name\n            ),\n        )\n        .add_inertial(pose=link2_elements_pose)\n        .add_visual(pose=link2_elements_pose)\n        # Note: NO .add_collision() call here\n        .build()\n    )\n\n    # Create model\n    rod_model = rod.Model(\n        name=\"model_missing_collision\",\n        canonical_link=link1.name,\n        link=[link1, link2],\n        joint=[link1_to_link2],\n    )\n\n    rod_model.switch_frame_convention(\n        frame_convention=rod.FrameConvention.Urdf,\n        explicit_frames=True,\n        attach_frames_to_links=True,\n    )\n\n    assert rod.Sdf(model=rod_model, version=\"1.10\").serialize(validate=True)\n\n    return rod_model\n\n\n@pytest.fixture(scope=\"session\")\ndef jaxsim_model_missing_collision() -> js.model.JaxSimModel:\n    \"\"\"\n    Fixture to create a model with a link that has a visual but no collision element.\n\n    This is used to test the export logic when collision elements are missing.\n    \"\"\"\n\n    rod_model = create_model_with_missing_collision()\n\n    urdf_string = rod.urdf.exporter.UrdfExporter(pretty=True).to_urdf_string(\n        sdf=rod_model\n    )\n\n    return build_jaxsim_model(model_description=urdf_string)\n\n\n@pytest.fixture(scope=\"session\")\ndef jaxsim_model_double_pendulum() -> js.model.JaxSimModel:\n    \"\"\"\n    Fixture providing the JaxSim model of a double pendulum.\n    Returns:\n        The JaxSim model of a double pendulum.\n    \"\"\"\n\n    model_path = pathlib.Path(__file__).parent / \"assets\" / 
\"double_pendulum.sdf\"\n    rod_model = load_model_from_file(model_path)\n    model = build_jaxsim_model(model_description=rod_model)\n\n    return model\n\n\n@pytest.fixture(scope=\"session\")\ndef jaxsim_model_cartpole() -> js.model.JaxSimModel:\n    \"\"\"\n    Fixture providing the JaxSim model of a cartpole.\n    Returns:\n        The JaxSim model of a cartpole.\n    \"\"\"\n\n    model_path = (\n        pathlib.Path(__file__).parent.parent / \"examples\" / \"assets\" / \"cartpole.urdf\"\n    )\n    rod_model = load_model_from_file(model_path, is_urdf=True)\n    model = build_jaxsim_model(model_description=rod_model)\n\n    return model\n\n\n@pytest.fixture(scope=\"session\")\ndef jaxsim_model_4_bar_linkage() -> js.model.JaxSimModel:\n    \"\"\"\n    Fixture providing the JaxSim model of a 4-bar linkage (opened configuration).\n\n    Returns:\n        The JaxSim model of the 4-bar linkage.\n    \"\"\"\n\n    model_path = pathlib.Path(__file__).parent / \"assets\" / \"4_bar_opened.urdf\"\n    rod_model = load_model_from_file(model_path, is_urdf=True)\n    model = build_jaxsim_model(model_description=rod_model)\n\n    return model\n\n\n# ============================\n# Collections of JaxSim models\n# ============================\n\n\ndef get_jaxsim_model_fixture(\n    model_name: str, request: pytest.FixtureRequest\n) -> str | pathlib.Path:\n    \"\"\"\n    Get the fixture providing the JaxSim model of a robot.\n\n    Args:\n        model_name: The name of the model.\n        request: The request object.\n\n    Returns:\n        The JaxSim model of the robot.\n\n    \"\"\"\n\n    match model_name:\n        case \"box\":\n            return request.getfixturevalue(jaxsim_model_box.__name__)\n        case \"sphere\":\n            return request.getfixturevalue(jaxsim_model_sphere.__name__)\n        case \"ergocub\":\n            return request.getfixturevalue(jaxsim_model_ergocub.__name__)\n        case \"ergocub_reduced\":\n            return 
request.getfixturevalue(jaxsim_model_ergocub_reduced.__name__)\n        case \"ur10\":\n            return request.getfixturevalue(jaxsim_model_ur10.__name__)\n        case \"single_pendulum\":\n            return request.getfixturevalue(jaxsim_model_single_pendulum.__name__)\n        case \"garpez\":\n            return request.getfixturevalue(jaxsim_model_garpez.__name__)\n        case \"garpez_scaled\":\n            return request.getfixturevalue(jaxsim_model_garpez_scaled.__name__)\n        case _:\n            raise ValueError(model_name)\n\n\n@pytest.fixture(\n    scope=\"session\",\n    params=[\n        \"box\",\n        \"sphere\",\n        \"ur10\",\n        \"ergocub\",\n        \"ergocub_reduced\",\n    ],\n)\ndef jaxsim_models_all(request) -> pathlib.Path | str:\n    \"\"\"\n    Fixture providing the JaxSim models of all supported robots.\n    \"\"\"\n\n    model_name: str = request.param\n    return get_jaxsim_model_fixture(model_name=model_name, request=request)\n\n\n@pytest.fixture(\n    scope=\"session\",\n    params=[\n        \"box\",\n        \"ur10\",\n        \"ergocub_reduced\",\n    ],\n)\ndef jaxsim_models_types(request) -> pathlib.Path | str:\n    \"\"\"\n    Fixture providing JaxSim models of all types of supported robots.\n\n    Note:\n        At the moment, most of our tests use this fixture. 
It provides:\n        - A robot with no joints.\n        - A fixed-base robot.\n        - A floating-base robot.\n\n    \"\"\"\n\n    model_name: str = request.param\n    return get_jaxsim_model_fixture(model_name=model_name, request=request)\n\n\n@pytest.fixture(\n    scope=\"session\",\n    params=[\n        \"box\",\n        \"sphere\",\n    ],\n)\ndef jaxsim_models_no_joints(request) -> pathlib.Path | str:\n    \"\"\"\n    Fixture providing JaxSim models of robots with no joints.\n    \"\"\"\n\n    model_name: str = request.param\n    return get_jaxsim_model_fixture(model_name=model_name, request=request)\n\n\n@pytest.fixture(\n    scope=\"session\",\n    params=[\n        \"ergocub\",\n        \"ergocub_reduced\",\n    ],\n)\ndef jaxsim_models_floating_base(request) -> pathlib.Path | str:\n    \"\"\"\n    Fixture providing JaxSim models of floating-base robots.\n    \"\"\"\n\n    model_name: str = request.param\n    return get_jaxsim_model_fixture(model_name=model_name, request=request)\n\n\n@pytest.fixture(\n    scope=\"session\",\n    params=[\n        \"ur10\",\n    ],\n)\ndef jaxsim_models_fixed_base(request) -> pathlib.Path | str:\n    \"\"\"\n    Fixture providing JaxSim models of fixed-base robots.\n    \"\"\"\n\n    model_name: str = request.param\n    return get_jaxsim_model_fixture(model_name=model_name, request=request)\n\n\n@pytest.fixture(scope=\"function\")\ndef set_jax_32bit(monkeypatch):\n    \"\"\"\n    Fixture that temporarily sets JAX precision to 32-bit for the duration of the test.\n    \"\"\"\n\n    del globals()[\"jaxsim\"]\n    del globals()[\"js\"]\n\n    # Temporarily disable x64\n    monkeypatch.setenv(\"JAX_ENABLE_X64\", \"0\")\n\n\n@pytest.fixture(scope=\"function\")\ndef jaxsim_model_box_32bit(set_jax_32bit, request) -> js.model.JaxSimModel:\n    \"\"\"\n    Fixture providing the JaxSim model of a box with 32-bit precision.\n\n    Returns:\n        The JaxSim model of a box with 32-bit precision.\n\n    \"\"\"\n\n    return 
get_jaxsim_model_fixture(model_name=\"box\", request=request)\n"
  },
  {
    "path": "tests/test_actuation.py",
    "content": "import jax.numpy as jnp\nfrom numpy.testing import assert_array_less\n\nimport jaxsim.api as js\nimport jaxsim.rbda\nfrom jaxsim import VelRepr\n\nfrom .utils import assert_allclose\n\n\ndef test_tn_curve(jaxsim_model_single_pendulum: js.model.JaxSimModel):\n\n    model = jaxsim_model_single_pendulum\n    new_act_params = jaxsim.rbda.actuation.ActuationParams()\n\n    with new_act_params.editable(validate=False) as new_act_params:\n        new_act_params.torque_max = 10\n        new_act_params.omega_th = 1\n        new_act_params.omega_max = 2\n\n    with model.editable(validate=False) as model:\n        model.actuation_params = new_act_params\n\n    data = js.data.JaxSimModelData.build(\n        model=model,\n        velocity_representation=VelRepr.Inertial,\n    )\n\n    new_joint_velocities = 1.5 * jnp.ones(model.dofs())\n    joint_torques_0 = 30 * jnp.ones(model.dofs())\n\n    data_0 = data.replace(model=model, joint_velocities=new_joint_velocities)\n\n    τ_total = js.actuation_model.compute_resultant_torques(\n        model, data_0, joint_force_references=joint_torques_0\n    )\n\n    assert_array_less(τ_total, joint_torques_0)\n\n    new_joint_velocities = 2.5 * jnp.ones(model.dofs())\n    joint_torques_0 = 30 * jnp.ones(model.dofs())\n    data_0 = data.replace(model=model, joint_velocities=new_joint_velocities)\n\n    τ_total = js.actuation_model.compute_resultant_torques(\n        model, data_0, joint_force_references=joint_torques_0\n    )\n\n    assert_allclose(τ_total, 0.0)\n"
  },
  {
    "path": "tests/test_api_com.py",
    "content": "import jax\n\nimport jaxsim.api as js\nfrom jaxsim import VelRepr\n\nfrom . import utils\nfrom .utils import assert_allclose\n\n\ndef test_com_properties(\n    jaxsim_models_types: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model, key=subkey, velocity_representation=velocity_representation\n    )\n\n    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)\n\n    # =====\n    # Tests\n    # =====\n\n    p_com_idt = kin_dyn.com_position()\n    p_com_js = js.com.com_position(model=model, data=data)\n\n    assert_allclose(p_com_idt, p_com_js)\n\n    J_Gh_idt = kin_dyn.centroidal_momentum_jacobian()\n    J_Gh_js = js.com.centroidal_momentum_jacobian(model=model, data=data)\n\n    assert_allclose(J_Gh_idt, J_Gh_js)\n\n    h_com_idt = kin_dyn.centroidal_momentum()\n    h_com_js = js.com.centroidal_momentum(model=model, data=data)\n\n    assert_allclose(h_com_idt, h_com_js)\n\n    M_com_locked_idt = kin_dyn.locked_centroidal_spatial_inertia()\n    M_com_locked_js = js.com.locked_centroidal_spatial_inertia(model=model, data=data)\n\n    assert_allclose(M_com_locked_idt, M_com_locked_js)\n\n    J_avg_com_idt = kin_dyn.average_centroidal_velocity_jacobian()\n    J_avg_com_js = js.com.average_centroidal_velocity_jacobian(model=model, data=data)\n\n    assert_allclose(J_avg_com_idt, J_avg_com_js)\n\n    v_avg_com_idt = kin_dyn.average_centroidal_velocity()\n    v_avg_com_js = js.com.average_centroidal_velocity(model=model, data=data)\n\n    assert_allclose(v_avg_com_idt, v_avg_com_js)\n\n    # https://github.com/gbionics/jaxsim/pull/117#discussion_r1535486123\n    if data.velocity_representation is not VelRepr.Body:\n        vl_com_idt = kin_dyn.com_velocity()\n        vl_com_js = js.com.com_linear_velocity(model=model, data=data)\n\n        
assert_allclose(vl_com_idt, vl_com_js)\n\n    # iDynTree provides the bias acceleration in G[W] frame regardless of the velocity\n    # representation. JaxSim, instead, returns the bias acceleration in G[B] when the\n    # active representation is VelRepr.Body.\n    if data.velocity_representation is not VelRepr.Body:\n        G_v̇_bias_WG_idt = kin_dyn.com_bias_acceleration()\n        G_v̇_bias_WG_js = js.com.bias_acceleration(model=model, data=data)\n\n        assert_allclose(G_v̇_bias_WG_idt, G_v̇_bias_WG_js)\n"
  },
  {
    "path": "tests/test_api_contact.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport rod\n\nimport jaxsim.api as js\nfrom jaxsim import VelRepr\n\nfrom .utils import assert_allclose\n\n\ndef test_contact_kinematics(\n    jaxsim_models_types: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model,\n        key=subkey,\n        velocity_representation=velocity_representation,\n    )\n\n    # Get the indices of the enabled collidable points.\n    indices_of_enabled_collidable_points = (\n        model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points\n    )\n\n    parent_link_idx_of_enabled_collidable_points = jnp.array(\n        model.kin_dyn_parameters.contact_parameters.body, dtype=int\n    )[indices_of_enabled_collidable_points]\n\n    # =====\n    # Tests\n    # =====\n\n    # Compute the pose of the implicit contact frame associated to the collidable points\n    # and the transforms of all links.\n    W_H_C = js.contact.transforms(model=model, data=data)\n    W_H_L = data._link_transforms\n\n    # Check that the orientation of the implicit contact frame matches with the\n    # orientation of the link to which the contact point is attached.\n    for contact_idx, index_of_parent_link in enumerate(\n        parent_link_idx_of_enabled_collidable_points\n    ):\n        assert_allclose(\n            W_H_C[contact_idx, 0:3, 0:3], W_H_L[index_of_parent_link][0:3, 0:3]\n        )\n\n    # Check that the origin of the implicit contact frame is located over the\n    # collidable point.\n    W_p_C = js.contact.collidable_point_positions(model=model, data=data)\n\n    assert_allclose(W_p_C, W_H_C[:, 0:3, 3])\n\n    # Compute the velocity of the collidable point.\n    # This quantity always matches with the linear component of the mixed 6D velocity\n    # of the implicit frame associated to the 
collidable point.\n    W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data)\n\n    # Compute the velocity of the collidable point using the contact Jacobian.\n    ν = data.generalized_velocity\n    CW_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed)\n    CW_vl_WC = jnp.einsum(\"c6g,g->c6\", CW_J_WC, ν)[:, 0:3]\n\n    # Compare the two velocities.\n    assert_allclose(W_ṗ_C, CW_vl_WC)\n\n\ndef test_collidable_point_jacobians(\n    jaxsim_models_types: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model, key=subkey, velocity_representation=velocity_representation\n    )\n\n    # =====\n    # Tests\n    # =====\n\n    # Compute the velocity of the collidable points with a RBDA.\n    # This function always returns the linear part of the mixed velocity of the\n    # implicit frame C corresponding to the collidable point.\n    W_ṗ_C = js.contact.collidable_point_velocities(model=model, data=data)\n\n    # Compute the generalized velocity and the free-floating Jacobian of the frame C.\n    ν = data.generalized_velocity\n    CW_J_WC = js.contact.jacobian(model=model, data=data, output_vel_repr=VelRepr.Mixed)\n\n    # Compute the velocity of the collidable points using the Jacobians.\n    v_WC_from_jax = jax.vmap(lambda J, ν: J @ ν, in_axes=(0, None))(CW_J_WC, ν)\n\n    assert_allclose(W_ṗ_C, v_WC_from_jax[:, 0:3])\n\n\ndef test_contact_jacobian_derivative(\n    jaxsim_models_floating_base: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_floating_base\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model,\n        key=subkey,\n        velocity_representation=velocity_representation,\n    )\n\n    # Get the 
indices of the enabled collidable points.\n    indices_of_enabled_collidable_points = (\n        model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points\n    )\n\n    # Extract the parent link names and the poses of the contact points.\n    parent_link_names = js.link.idxs_to_names(\n        model=model,\n        link_indices=jnp.array(\n            model.kin_dyn_parameters.contact_parameters.body, dtype=int\n        )[indices_of_enabled_collidable_points],\n    )\n\n    L_p_Ci = model.kin_dyn_parameters.contact_parameters.point[\n        indices_of_enabled_collidable_points\n    ]\n\n    # =====\n    # Tests\n    # =====\n\n    # Load the model in ROD.\n    rod_model = rod.Sdf.load(sdf=model.built_from).model\n\n    # Add dummy frames on the contact points.\n    for idx, link_name, L_p_C in zip(\n        indices_of_enabled_collidable_points, parent_link_names, L_p_Ci, strict=True\n    ):\n        rod_model.add_frame(\n            frame=rod.Frame(\n                name=f\"contact_point_{idx}\",\n                attached_to=link_name,\n                pose=rod.Pose(\n                    relative_to=link_name, pose=jnp.zeros(shape=(6,)).at[0:3].set(L_p_C)\n                ),\n            ),\n        )\n\n    # Rebuild the JaxSim model.\n    model_with_frames = js.model.JaxSimModel.build_from_model_description(\n        model_description=rod_model\n    )\n    model_with_frames = js.model.reduce(\n        model=model_with_frames, considered_joints=model.joint_names()\n    )\n\n    # Rebuild the JaxSim data.\n    data_with_frames = js.data.JaxSimModelData.build(\n        model=model_with_frames,\n        base_position=data.base_position,\n        base_quaternion=data.base_orientation,\n        joint_positions=data.joint_positions,\n        base_linear_velocity=data.base_velocity[0:3],\n        base_angular_velocity=data.base_velocity[3:6],\n        joint_velocities=data.joint_velocities,\n        
velocity_representation=velocity_representation,\n    )\n\n    # Extract the indexes of the frames attached to the contact points.\n    frame_idxs = js.frame.names_to_idxs(\n        model=model_with_frames,\n        frame_names=tuple(\n            f\"contact_point_{idx}\" for idx in indices_of_enabled_collidable_points\n        ),\n    )\n\n    # Check that the number of frames is correct.\n    assert len(frame_idxs) == len(parent_link_names)\n\n    # Compute the contact Jacobian derivative.\n    J̇_WC = js.contact.jacobian_derivative(\n        model=model_with_frames, data=data_with_frames\n    )\n\n    # Compute the contact Jacobian derivative using frames.\n    J̇_WF = jax.vmap(\n        js.frame.jacobian_derivative,\n        in_axes=(None, None),\n    )(model_with_frames, data_with_frames, frame_index=frame_idxs)\n\n    # Compare the two Jacobians.\n    assert_allclose(J̇_WC, J̇_WF)\n"
  },
  {
    "path": "tests/test_api_data.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport pytest\nfrom numpy.testing import assert_raises\n\nimport jaxsim.api as js\nfrom jaxsim import VelRepr\nfrom jaxsim.utils import Mutability\n\nfrom . import utils\nfrom .utils import assert_allclose\n\n\ndef test_data_valid(\n    jaxsim_models_all: js.model.JaxSimModel,\n):\n\n    model = jaxsim_models_all\n    data = js.data.JaxSimModelData.build(model=model)\n\n    assert data.valid(model=model)\n\n\ndef test_data_switch_velocity_representation(\n    jaxsim_models_types: js.model.JaxSimModel,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model, key=subkey, velocity_representation=VelRepr.Inertial\n    )\n\n    # =====\n    # Tests\n    # =====\n\n    new_base_linear_velocity = jnp.array([1.0, -2.0, 3.0])\n    old_base_linear_velocity = data._base_linear_velocity\n\n    # The following should not change the original `data` object since it raises.\n    with pytest.raises(RuntimeError):\n        with data.switch_velocity_representation(\n            velocity_representation=VelRepr.Inertial\n        ):\n            with data.mutable_context(mutability=Mutability.MUTABLE):\n                data._base_linear_velocity = new_base_linear_velocity\n            raise RuntimeError(\"This is raised on purpose inside this context\")\n\n    assert_allclose(data._base_linear_velocity, old_base_linear_velocity)\n\n    # The following instead should result to an updated `data` object.\n    with (\n        data.switch_velocity_representation(velocity_representation=VelRepr.Inertial),\n        data.mutable_context(mutability=Mutability.MUTABLE),\n    ):\n        data._base_linear_velocity = new_base_linear_velocity\n\n    assert_allclose(data._base_linear_velocity, new_base_linear_velocity)\n\n\ndef test_data_change_velocity_representation(\n    jaxsim_models_types: js.model.JaxSimModel,\n    
prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model, key=subkey, velocity_representation=VelRepr.Inertial\n    )\n\n    # =====\n    # Tests\n    # =====\n\n    kin_dyn_inertial = utils.build_kindyncomputations_from_jaxsim_model(\n        model=model, data=data\n    )\n\n    with data.switch_velocity_representation(VelRepr.Mixed):\n        kin_dyn_mixed = utils.build_kindyncomputations_from_jaxsim_model(\n            model=model, data=data\n        )\n\n    with data.switch_velocity_representation(VelRepr.Body):\n        kin_dyn_body = utils.build_kindyncomputations_from_jaxsim_model(\n            model=model, data=data\n        )\n\n    assert_allclose(data.base_velocity, kin_dyn_inertial.base_velocity())\n\n    if not model.floating_base():\n        return\n\n    with data.switch_velocity_representation(VelRepr.Mixed):\n        assert_allclose(data.base_velocity, kin_dyn_mixed.base_velocity())\n        assert_raises(\n            AssertionError,\n            assert_allclose,\n            data.base_velocity[0:3],\n            data._base_linear_velocity,\n        )\n        assert_allclose(data.base_velocity[3:6], data._base_angular_velocity)\n\n    with data.switch_velocity_representation(VelRepr.Body):\n        assert_allclose(data.base_velocity, kin_dyn_body.base_velocity())\n        assert_raises(\n            AssertionError,\n            assert_allclose,\n            data.base_velocity[0:3],\n            data._base_linear_velocity,\n        )\n        assert_raises(\n            AssertionError,\n            assert_allclose,\n            data.base_velocity[3:6],\n            data._base_angular_velocity,\n        )\n"
  },
  {
    "path": "tests/test_api_frame.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport pytest\nfrom jax.errors import JaxRuntimeError\nfrom numpy.testing import assert_array_equal\n\nimport jaxsim.api as js\nfrom jaxsim import VelRepr\nfrom jaxsim.math.quaternion import Quaternion\n\nfrom . import utils\nfrom .utils import assert_allclose\n\n\ndef test_frame_index(jaxsim_models_types: js.model.JaxSimModel):\n\n    model = jaxsim_models_types\n\n    # =====\n    # Tests\n    # =====\n\n    n_l = model.number_of_links()\n    n_f = len(model.frame_names())\n\n    for idx, frame_name in enumerate(model.frame_names()):\n        frame_index = n_l + idx\n        assert js.frame.name_to_idx(model=model, frame_name=frame_name) == frame_index\n        assert js.frame.idx_to_name(model=model, frame_index=frame_index) == frame_name\n        assert (\n            js.frame.idx_of_parent_link(model=model, frame_index=frame_index)\n            < model.number_of_links()\n        )\n\n    # See discussion in https://github.com/gbionics/jaxsim/pull/280\n    assert_array_equal(\n        js.frame.names_to_idxs(model=model, frame_names=model.frame_names()),\n        jnp.arange(n_l, n_l + n_f),\n    )\n\n    assert (\n        js.frame.idxs_to_names(\n            model=model,\n            frame_indices=tuple(\n                js.frame.names_to_idxs(\n                    model=model, frame_names=model.frame_names()\n                ).tolist()\n            ),\n        )\n        == model.frame_names()\n    )\n\n    with pytest.raises(ValueError):\n        _ = js.frame.name_to_idx(model=model, frame_name=\"non_existent_frame\")\n\n    with pytest.raises(JaxRuntimeError):\n        _ = js.frame.idx_to_name(model=model, frame_index=-1)\n\n    with pytest.raises(JaxRuntimeError):\n        _ = js.frame.idx_to_name(model=model, frame_index=n_l - 1)\n\n    with pytest.raises(JaxRuntimeError):\n        _ = js.frame.idx_to_name(model=model, frame_index=n_l + n_f)\n\n    with pytest.raises(JaxRuntimeError):\n        _ = 
js.frame.idx_of_parent_link(model=model, frame_index=-1)\n\n    with pytest.raises(JaxRuntimeError):\n        _ = js.frame.idx_of_parent_link(model=model, frame_index=n_l - 1)\n\n    with pytest.raises(JaxRuntimeError):\n        _ = js.frame.idx_of_parent_link(model=model, frame_index=n_l + n_f)\n\n\ndef test_frame_transforms(\n    jaxsim_models_types: js.model.JaxSimModel,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model, key=subkey, velocity_representation=VelRepr.Inertial\n    )\n\n    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)\n\n    # Get all names of frames in the iDynTree model.\n    frame_names = [\n        frame.name\n        for frame in model.description.frames\n        if frame.name in kin_dyn.frame_names()\n    ]\n\n    # Skip skin, laser, and depth frames of models with many frames.\n    frame_names = [\n        name\n        for name in frame_names\n        if \"skin\" not in name and \"laser\" not in name and \"depth\" not in name\n    ]\n\n    # Get indices of frames.\n    frame_indices = tuple(\n        frame.index\n        for frame in model.description.frames\n        if frame.index is not None and frame.name in frame_names\n    )\n\n    # =====\n    # Tests\n    # =====\n\n    assert len(frame_indices) == len(frame_names)\n\n    for frame_name in frame_names:\n\n        W_H_F_js = js.frame.transform(\n            model=model,\n            data=data,\n            frame_index=js.frame.name_to_idx(model=model, frame_name=frame_name),\n        )\n        W_H_F_idt = kin_dyn.frame_transform(frame_name=frame_name)\n        assert_allclose(\n            W_H_F_js, W_H_F_idt, atol=1e-6, err_msg=f\"Mismatch in {frame_name}\"\n        )\n\n\ndef test_frame_jacobians(\n    jaxsim_models_types: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    prng_key: jax.Array,\n):\n\n    model = 
jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model, key=subkey, velocity_representation=velocity_representation\n    )\n\n    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)\n\n    # Get all names of frames in the iDynTree model.\n    frame_names = [\n        frame.name\n        for frame in model.description.frames\n        if frame.name in kin_dyn.frame_names()\n    ]\n\n    # Lower the number of frames for models with many frames.\n    if model.name().lower() == \"ergocub\":\n        assert any(\"sole\" in name for name in frame_names)\n        frame_names = [name for name in frame_names if \"sole\" in name]\n\n    # Get indices of frames.\n    frame_indices = tuple(\n        frame.index\n        for frame in model.description.frames\n        if frame.index is not None and frame.name in frame_names\n    )\n\n    # =====\n    # Tests\n    # =====\n\n    assert len(frame_indices) == len(frame_names)\n\n    for frame_name, frame_index in zip(frame_names, frame_indices, strict=True):\n\n        J_WL_js = js.frame.jacobian(model=model, data=data, frame_index=frame_index)\n        J_WL_idt = kin_dyn.jacobian_frame(frame_name=frame_name)\n        assert_allclose(J_WL_js, J_WL_idt, err_msg=f\"Mismatch in {frame_name}\")\n\n    for frame_name, frame_index in zip(frame_names, frame_indices, strict=True):\n\n        v_WF_idt = kin_dyn.frame_velocity(frame_name=frame_name)\n        v_WF_js = js.frame.velocity(model=model, data=data, frame_index=frame_index)\n        assert_allclose(v_WF_js, v_WF_idt, err_msg=f\"Mismatch in {frame_name}\")\n\n\ndef test_frame_jacobian_derivative(\n    jaxsim_models_types: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    prng_key: jax.Array,\n):\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model, key=subkey, 
velocity_representation=velocity_representation\n    )\n\n    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)\n\n    # Get all names of frames in the iDynTree model.\n    frame_names = [\n        frame.name\n        for frame in model.description.frames\n        if frame.name in kin_dyn.frame_names()\n    ]\n\n    # Skip skin, laser, and depth frames of models with many frames.\n    frame_names = [\n        name\n        for name in frame_names\n        if \"skin\" not in name and \"laser\" not in name and \"depth\" not in name\n    ]\n\n    frame_idxs = js.frame.names_to_idxs(model=model, frame_names=tuple(frame_names))\n\n    # ===============\n    # Test against AD\n    # ===============\n\n    # Get the generalized velocity.\n    I_ν = data.generalized_velocity\n\n    # Compute J̇.\n    O_J̇_WF_I = jax.vmap(\n        lambda frame_index: js.frame.jacobian_derivative(\n            model=model, data=data, frame_index=frame_index\n        )\n    )(frame_idxs)\n\n    assert O_J̇_WF_I.shape == (len(frame_names), 6, 6 + model.dofs())\n\n    # Compute the plain Jacobian.\n    # This function will be used to compute the Jacobian derivative with AD.\n    def J(q, frame_idxs) -> jax.Array:\n        data_ad = js.data.JaxSimModelData.build(\n            model=model,\n            velocity_representation=data.velocity_representation,\n            base_position=q[:3],\n            base_quaternion=q[3:7],\n            joint_positions=q[7:],\n        )\n\n        O_J_ad_WF_I = jax.vmap(\n            lambda model, data, frame_index: js.frame.jacobian(\n                model=model, data=data, frame_index=frame_index\n            ),\n            in_axes=(None, None, 0),\n        )(model, data_ad, frame_idxs)\n\n        return O_J_ad_WF_I\n\n    def compute_q(data: js.data.JaxSimModelData) -> jax.Array:\n        q = jnp.hstack(\n            [\n                data.base_position,\n                data.base_orientation,\n                data.joint_positions,\n            
]\n        )\n\n        return q\n\n    def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:\n        with data.switch_velocity_representation(VelRepr.Body):\n            B_ω_WB = data.base_velocity[3:6]\n\n        with data.switch_velocity_representation(VelRepr.Mixed):\n            W_ṗ_B = data.base_velocity[0:3]\n\n        W_Q̇_B = Quaternion.derivative(\n            quaternion=data.base_orientation,\n            omega=B_ω_WB,\n            omega_in_body_fixed=True,\n            K=0.0,\n        ).squeeze()\n\n        q̇ = jnp.hstack([W_ṗ_B, W_Q̇_B, data.joint_velocities])\n\n        return q̇\n\n    # Compute q and q̇.\n    q = compute_q(data)\n    q̇ = compute_q̇(data)\n\n    # Compute dJ/dt with AD.\n    dJ_dq = jax.jacfwd(J, argnums=0)(q, frame_idxs)\n    O_J̇_ad_WF_I = jnp.einsum(\"ijkq,q->ijk\", dJ_dq, q̇)\n\n    assert_allclose(O_J̇_WF_I, O_J̇_ad_WF_I)\n\n    # =====================\n    # Test against iDynTree\n    # =====================\n\n    # Compute the product J̇ν.\n    O_a_bias_WF = jax.vmap(\n        lambda O_J̇_WF_I, I_ν: O_J̇_WF_I @ I_ν,\n        in_axes=(0, None),\n    )(O_J̇_WF_I, I_ν)\n\n    # Compare the two computations.\n    for index, name in enumerate(frame_names):\n        J̇ν_idt = kin_dyn.frame_bias_acc(frame_name=name)\n        J̇ν_js = O_a_bias_WF[index]\n        assert_allclose(J̇ν_js, J̇ν_idt, err_msg=f\"Mismatch in {name}\")\n"
  },
  {
    "path": "tests/test_api_joint.py",
    "content": "import jax.numpy as jnp\nimport pytest\nfrom jax.errors import JaxRuntimeError\nfrom numpy.testing import assert_array_equal\n\nimport jaxsim.api as js\n\n\ndef test_joint_index(\n    jaxsim_models_types: js.model.JaxSimModel,\n):\n\n    model = jaxsim_models_types\n\n    # =====\n    # Tests\n    # =====\n\n    for idx, joint_name in enumerate(model.joint_names()):\n        assert js.joint.name_to_idx(model=model, joint_name=joint_name) == idx\n        assert js.joint.idx_to_name(model=model, joint_index=idx) == joint_name\n\n    # See discussion in https://github.com/gbionics/jaxsim/pull/280\n    assert_array_equal(\n        js.joint.names_to_idxs(model=model, joint_names=model.joint_names()),\n        jnp.arange(model.number_of_joints()),\n    )\n\n    assert (\n        js.joint.idxs_to_names(\n            model=model,\n            joint_indices=tuple(\n                js.joint.names_to_idxs(\n                    model=model, joint_names=model.joint_names()\n                ).tolist()\n            ),\n        )\n        == model.joint_names()\n    )\n\n    with pytest.raises(ValueError):\n        _ = js.joint.name_to_idx(model=model, joint_name=\"non_existent_joint\")\n\n    with pytest.raises(JaxRuntimeError):\n        _ = js.joint.idx_to_name(model=model, joint_index=-1)\n\n    with pytest.raises(IndexError):\n        _ = js.joint.idx_to_name(model=model, joint_index=model.number_of_joints())\n"
  },
  {
    "path": "tests/test_api_link.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport pytest\nfrom jax.errors import JaxRuntimeError\nfrom numpy.testing import assert_array_equal\n\nimport jaxsim.api as js\nimport jaxsim.math\nfrom jaxsim import VelRepr\n\nfrom . import utils\nfrom .utils import assert_allclose\n\n\ndef test_link_index(\n    jaxsim_models_types: js.model.JaxSimModel,\n):\n\n    model = jaxsim_models_types\n\n    # =====\n    # Tests\n    # =====\n\n    for idx, link_name in enumerate(model.link_names()):\n        assert js.link.name_to_idx(model=model, link_name=link_name) == idx\n        assert js.link.idx_to_name(model=model, link_index=idx) == link_name\n\n    # See discussion in https://github.com/gbionics/jaxsim/pull/280\n    assert_array_equal(\n        js.link.names_to_idxs(model=model, link_names=model.link_names()),\n        jnp.arange(model.number_of_links()),\n    )\n\n    assert (\n        js.link.idxs_to_names(\n            model=model,\n            link_indices=tuple(\n                js.link.names_to_idxs(\n                    model=model, link_names=model.link_names()\n                ).tolist()\n            ),\n        )\n        == model.link_names()\n    )\n\n    with pytest.raises(ValueError):\n        _ = js.link.name_to_idx(model=model, link_name=\"non_existent_link\")\n\n    with pytest.raises(JaxRuntimeError):\n        _ = js.link.idx_to_name(model=model, link_index=-1)\n\n    with pytest.raises(IndexError):\n        _ = js.link.idx_to_name(model=model, link_index=model.number_of_links())\n\n\ndef test_link_inertial_properties(\n    jaxsim_models_types: js.model.JaxSimModel,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model,\n        key=subkey,\n        velocity_representation=VelRepr.Inertial,\n    )\n\n    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)\n\n    # =====\n    # Tests\n    # 
=====\n\n    for link_name, link_idx in zip(\n        model.link_names(),\n        jnp.arange(model.number_of_links()),\n        strict=True,\n    ):\n        if link_name == model.base_link():\n            continue\n\n        assert_allclose(\n            js.link.mass(model=model, link_index=link_idx),\n            kin_dyn.link_mass(link_name=link_name),\n            err_msg=f\"Mismatch in {link_name}\",\n        )\n\n        assert_allclose(\n            js.link.spatial_inertia(model=model, link_index=link_idx),\n            kin_dyn.link_spatial_inertia(link_name=link_name),\n            err_msg=f\"Mismatch in {link_name}\",\n        )\n\n\ndef test_link_transforms(\n    jaxsim_models_types: js.model.JaxSimModel,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model,\n        key=subkey,\n        velocity_representation=VelRepr.Inertial,\n    )\n\n    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)\n\n    # =====\n    # Tests\n    # =====\n\n    W_H_LL_model = data._link_transforms\n\n    W_H_LL_links = jax.vmap(\n        lambda idx: js.link.transform(model=model, data=data, link_index=idx)\n    )(jnp.arange(model.number_of_links()))\n\n    assert_allclose(W_H_LL_model, W_H_LL_links)\n\n    for W_H_L, link_name in zip(W_H_LL_links, model.link_names(), strict=True):\n\n        assert_allclose(\n            W_H_L,\n            kin_dyn.frame_transform(frame_name=link_name),\n            err_msg=f\"Mismatch in {link_name}\",\n        )\n\n\ndef test_link_jacobians(\n    jaxsim_models_types: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model,\n        key=subkey,\n        velocity_representation=velocity_representation,\n    
)\n\n    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)\n\n    # =====\n    # Tests\n    # =====\n\n    J_WL_links = jax.vmap(\n        lambda idx: js.link.jacobian(model=model, data=data, link_index=idx)\n    )(jnp.arange(model.number_of_links()))\n\n    for J_WL, link_name in zip(J_WL_links, model.link_names(), strict=True):\n        assert_allclose(\n            J_WL,\n            kin_dyn.jacobian_frame(frame_name=link_name),\n            err_msg=f\"Mismatch in {link_name}\",\n        )\n\n    # The following is true only in inertial-fixed representation.\n    J_WL_model = js.model.generalized_free_floating_jacobian(model=model, data=data)\n\n    assert_allclose(J_WL_model, J_WL_links)\n\n    for link_name, link_idx in zip(\n        model.link_names(),\n        jnp.arange(model.number_of_links()),\n        strict=True,\n    ):\n        v_WL_idt = kin_dyn.frame_velocity(frame_name=link_name)\n        v_WL_js = js.link.velocity(model=model, data=data, link_index=link_idx)\n\n        assert_allclose(v_WL_js, v_WL_idt, err_msg=f\"Mismatch in {link_name}\")\n\n    # Test conversion to a different output velocity representation.\n    for other_repr in {VelRepr.Inertial, VelRepr.Body, VelRepr.Mixed}.difference(\n        {data.velocity_representation}\n    ):\n\n        with data.switch_velocity_representation(other_repr):\n            kin_dyn_other_repr = utils.build_kindyncomputations_from_jaxsim_model(\n                model=model, data=data\n            )\n\n        for link_name, link_idx in zip(\n            model.link_names(),\n            jnp.arange(model.number_of_links()),\n            strict=True,\n        ):\n            v_WL_idt = kin_dyn_other_repr.frame_velocity(frame_name=link_name)\n            v_WL_js = js.link.velocity(\n                model=model, data=data, link_index=link_idx, output_vel_repr=other_repr\n            )\n\n            assert_allclose(v_WL_js, v_WL_idt, err_msg=f\"Mismatch in 
{link_name}\")\n\n\ndef test_link_bias_acceleration(\n    jaxsim_models_types: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model,\n        key=subkey,\n        velocity_representation=velocity_representation,\n    )\n\n    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)\n\n    # =====\n    # Tests\n    # =====\n\n    for name, index in zip(\n        model.link_names(),\n        jnp.arange(model.number_of_links()),\n        strict=True,\n    ):\n        Jν_idt = kin_dyn.frame_bias_acc(frame_name=name)\n        Jν_js = js.link.bias_acceleration(model=model, data=data, link_index=index)\n\n        assert_allclose(Jν_js, Jν_idt, err_msg=f\"Mismatch in {name}\")\n\n    # Test that the conversion of the link bias acceleration works as expected.\n    match data.velocity_representation:\n\n        # We exclude the mixed representation because converting the acceleration is\n        # more complex than using the plain 6D transform matrix.\n        case VelRepr.Mixed:\n            pass\n\n        # Inertial-fixed to body-fixed conversion.\n        case VelRepr.Inertial:\n\n            W_H_L = data._link_transforms\n\n            W_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)\n\n            with data.switch_velocity_representation(VelRepr.Body):\n\n                W_X_L = jax.vmap(\n                    lambda W_H_L: jaxsim.math.Adjoint.from_transform(transform=W_H_L)\n                )(W_H_L)\n\n                L_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)\n\n                W_a_bias_WL_converted = jax.vmap(\n                    lambda W_X_L, L_a_bias_WL: W_X_L @ L_a_bias_WL\n                )(W_X_L, L_a_bias_WL)\n\n            assert_allclose(W_a_bias_WL, W_a_bias_WL_converted)\n\n        # Body-fixed to 
inertial-fixed conversion.\n        case VelRepr.Body:\n\n            W_H_L = data._link_transforms\n\n            L_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)\n\n            with data.switch_velocity_representation(VelRepr.Inertial):\n\n                L_X_W = jax.vmap(\n                    lambda W_H_L: jaxsim.math.Adjoint.from_transform(\n                        transform=W_H_L, inverse=True\n                    )\n                )(W_H_L)\n\n                W_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)\n\n                L_a_bias_WL_converted = jax.vmap(\n                    lambda L_X_W, W_a_bias_WL: L_X_W @ W_a_bias_WL\n                )(L_X_W, W_a_bias_WL)\n\n            assert_allclose(L_a_bias_WL, L_a_bias_WL_converted)\n\n\ndef test_link_jacobian_derivative(\n    jaxsim_models_types: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model,\n        key=subkey,\n        velocity_representation=velocity_representation,\n    )\n\n    # =====\n    # Tests\n    # =====\n\n    # Get the generalized velocity.\n    I_ν = data.generalized_velocity\n\n    # Compute J̇.\n    O_J̇_WL_I = jax.vmap(\n        lambda link_index: js.link.jacobian_derivative(\n            model=model, data=data, link_index=link_index\n        )\n    )(jnp.arange(model.number_of_links()))\n\n    # Compute the product J̇ν.\n    O_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)\n\n    # Compare the two computations.\n    assert_allclose(jnp.einsum(\"l6g,g->l6\", O_J̇_WL_I, I_ν), O_a_bias_WL)\n\n    # Compute the plain Jacobian.\n    # This function will be used to compute the Jacobian derivative with AD.\n    # Given q, computing J̇ by AD-ing this function should work out-of-the-box with\n    # all velocity representations, that are handled 
internally when computing J.\n    def J(q) -> jax.Array:\n\n        data_ad = js.data.JaxSimModelData.build(\n            model=model,\n            velocity_representation=data.velocity_representation,\n            base_position=q[:3],\n            base_quaternion=q[3:7],\n            joint_positions=q[7:],\n        )\n\n        O_J_WL_I = js.model.generalized_free_floating_jacobian(\n            model=model, data=data_ad\n        )\n\n        return O_J_WL_I\n\n    def compute_q(data: js.data.JaxSimModelData) -> jax.Array:\n\n        q = jnp.hstack(\n            [data.base_position, data.base_orientation, data.joint_positions]\n        )\n\n        return q\n\n    def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:\n\n        with data.switch_velocity_representation(VelRepr.Body):\n            B_ω_WB = data.base_velocity[3:6]\n\n        with data.switch_velocity_representation(VelRepr.Mixed):\n            W_ṗ_B = data.base_velocity[0:3]\n\n        W_Q̇_B = jaxsim.math.Quaternion.derivative(\n            quaternion=data.base_orientation,\n            omega=B_ω_WB,\n            omega_in_body_fixed=True,\n            K=0.0,\n        ).squeeze()\n\n        q̇ = jnp.hstack([W_ṗ_B, W_Q̇_B, data.joint_velocities])\n\n        return q̇\n\n    # Compute q and q̇.\n    q = compute_q(data)\n    q̇ = compute_q̇(data)\n\n    # Compute dJ/dt with AD.\n    dJ_dq = jax.jacfwd(J, argnums=0)(q)\n    O_J̇_ad_WL_I = jnp.einsum(\"ijkq,q->ijk\", dJ_dq, q̇)\n\n    assert_allclose(O_J̇_WL_I, O_J̇_ad_WL_I)\n    assert_allclose(\n        jnp.einsum(\"l6g,g->l6\", O_J̇_ad_WL_I, I_ν),\n        jnp.einsum(\"l6g,g->l6\", O_J̇_WL_I, I_ν),\n    )\n"
  },
  {
    "path": "tests/test_api_model.py",
    "content": "import pathlib\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport rod\n\nimport jaxsim.api as js\nimport jaxsim.math\nfrom jaxsim import VelRepr\n\nfrom . import utils\nfrom .utils import assert_allclose\n\n\ndef test_model_creation_and_reduction(\n    jaxsim_model_ergocub: js.model.JaxSimModel,\n    prng_key: jax.Array,\n):\n\n    model_full = jaxsim_model_ergocub\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data_full = js.data.random_model_data(\n        model=model_full,\n        key=subkey,\n        velocity_representation=VelRepr.Inertial,\n        base_pos_bounds=((0, 0, 0.8), (0, 0, 0.8)),\n    )\n\n    # =====\n    # Tests\n    # =====\n\n    # Check that the data of the full model is valid.\n    assert data_full.valid(model=model_full)\n\n    # Build the ROD model from the original description.\n    assert isinstance(model_full.built_from, str | pathlib.Path)\n    rod_sdf = rod.Sdf.load(sdf=model_full.built_from)\n    assert len(rod_sdf.models()) == 1\n\n    # Get all non-fixed joint names from the description.\n    joint_names_in_description = [\n        j.name for j in rod_sdf.models()[0].joints() if j.type != \"fixed\"\n    ]\n\n    # Check that all non-fixed joints are in the full model.\n    assert set(joint_names_in_description) == set(model_full.joint_names())\n\n    # ================\n    # Reduce the model\n    # ================\n\n    # Get the names of the joints to keep in the reduced model.\n    reduced_joints = tuple(\n        j\n        for j in model_full.joint_names()\n        if \"camera\" not in j\n        and \"neck\" not in j\n        and \"wrist\" not in j\n        and \"thumb\" not in j\n        and \"index\" not in j\n        and \"middle\" not in j\n        and \"ring\" not in j\n        and \"pinkie\" not in j\n        #\n        and \"elbow\" not in j\n        and \"shoulder\" not in j\n        and \"torso\" not in j\n        and \"r_knee\" not in j\n    )\n\n    # Reduce the 
model.\n    # Note: here we also specify a non-zero position of the removed joints.\n    # The process should take into account the corresponding joint transforms\n    # when the link-joint-link chains are lumped together.\n    model_reduced = js.model.reduce(\n        model=model_full,\n        considered_joints=reduced_joints,\n        locked_joint_positions=dict(\n            zip(\n                model_full.joint_names(),\n                data_full.joint_positions.tolist(),\n                strict=True,\n            )\n        ),\n    )\n\n    # Check DoFs.\n    assert model_full.dofs() != model_reduced.dofs()\n\n    # Check that all non-fixed joints are in the reduced model.\n    assert set(reduced_joints) == set(model_reduced.joint_names())\n\n    # Check that the reduced model maintains the same terrain of the full model.\n    assert model_full.terrain == model_reduced.terrain\n\n    # Check that the reduced model maintains the same contact model of the full model.\n    assert model_full.contact_model == model_reduced.contact_model\n\n    # Check that the reduced model maintains the same integration step of the full model.\n    assert model_full.time_step == model_reduced.time_step\n\n    joint_idxs = js.joint.names_to_idxs(\n        model=model_full, joint_names=model_reduced.joint_names()\n    )\n\n    # Build the data of the reduced model.\n    data_reduced = js.data.JaxSimModelData.build(\n        model=model_reduced,\n        base_position=data_full.base_position,\n        base_quaternion=data_full.base_orientation,\n        joint_positions=data_full.joint_positions[joint_idxs],\n        base_linear_velocity=data_full.base_velocity[0:3],\n        base_angular_velocity=data_full.base_velocity[3:6],\n        joint_velocities=data_full.joint_velocities[joint_idxs],\n        velocity_representation=data_full.velocity_representation,\n    )\n\n    # Check that the reduced model data is valid.\n    assert not data_reduced.valid(model=model_full)\n    assert 
data_reduced.valid(model=model_reduced)\n\n    # Check that the total mass is preserved.\n    assert_allclose(\n        js.model.total_mass(model=model_full), js.model.total_mass(model=model_reduced)\n    )\n\n    # Check that the CoM position is preserved.\n    assert_allclose(\n        js.com.com_position(model=model_full, data=data_full),\n        js.com.com_position(model=model_reduced, data=data_reduced),\n        atol=1e-6,\n    )\n\n    # Check that joint serialization works.\n    assert_allclose(data_full.joint_positions[joint_idxs], data_reduced.joint_positions)\n    assert_allclose(\n        data_full.joint_velocities[joint_idxs], data_reduced.joint_velocities\n    )\n\n    # Check that link transforms are preserved.\n    for link_name in model_reduced.link_names():\n        W_H_L_full = js.link.transform(\n            model=model_full,\n            data=data_full,\n            link_index=js.link.name_to_idx(model=model_full, link_name=link_name),\n        )\n        W_H_L_reduced = js.link.transform(\n            model=model_reduced,\n            data=data_reduced,\n            link_index=js.link.name_to_idx(model=model_reduced, link_name=link_name),\n        )\n        assert_allclose(W_H_L_full, W_H_L_reduced)\n\n    # Check that collidable point positions are preserved.\n    assert_allclose(\n        js.contact.collidable_point_positions(model=model_full, data=data_full),\n        js.contact.collidable_point_positions(model=model_reduced, data=data_reduced),\n    )\n\n    # =====================\n    # Test against iDynTree\n    # =====================\n\n    kin_dyn_full = utils.build_kindyncomputations_from_jaxsim_model(\n        model=model_full, data=data_full\n    )\n\n    kin_dyn_reduced = utils.build_kindyncomputations_from_jaxsim_model(\n        model=model_reduced, data=data_reduced\n    )\n\n    # Check that the total mass is preserved.\n    assert_allclose(kin_dyn_full.total_mass(), kin_dyn_reduced.total_mass())\n\n    # Check that the CoM 
positions match.\n    assert_allclose(kin_dyn_full.com_position(), kin_dyn_reduced.com_position())\n    assert_allclose(\n        kin_dyn_full.com_position(),\n        js.com.com_position(model=model_reduced, data=data_reduced),\n    )\n    # Check that link transforms match.\n    for link_name in model_reduced.link_names():\n\n        assert_allclose(\n            kin_dyn_reduced.frame_transform(frame_name=link_name),\n            kin_dyn_full.frame_transform(frame_name=link_name),\n            err_msg=f\"Mismatch in link {link_name}\",\n        )\n\n        assert_allclose(\n            kin_dyn_reduced.frame_transform(frame_name=link_name),\n            js.link.transform(\n                model=model_reduced,\n                data=data_reduced,\n                link_index=js.link.name_to_idx(\n                    model=model_reduced, link_name=link_name\n                ),\n            ),\n            err_msg=f\"Mismatch in link {link_name}\",\n        )\n\n    # Check that frame transforms match.\n    for frame_name in model_reduced.frame_names():\n\n        if frame_name not in kin_dyn_reduced.frame_names():\n            continue\n\n        # Skip the skin, laser, and depth frames of models with many frames.\n        if \"skin\" in frame_name or \"laser\" in frame_name or \"depth\" in frame_name:\n            continue\n\n        assert_allclose(\n            kin_dyn_reduced.frame_transform(frame_name=frame_name),\n            kin_dyn_full.frame_transform(frame_name=frame_name),\n            err_msg=f\"Mismatch in frame {frame_name}\",\n        )\n\n        assert_allclose(\n            kin_dyn_reduced.frame_transform(frame_name=frame_name),\n            js.frame.transform(\n                model=model_reduced,\n                data=data_reduced,\n                frame_index=js.frame.name_to_idx(\n                    model=model_reduced, frame_name=frame_name\n                ),\n            ),\n            err_msg=f\"Mismatch in frame {frame_name}\",\n        )\n\n\ndef 
test_model_properties(\n    jaxsim_models_types: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model, key=subkey, velocity_representation=velocity_representation\n    )\n\n    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)\n\n    # =====\n    # Tests\n    # =====\n\n    m_idt = kin_dyn.total_mass()\n    m_js = js.model.total_mass(model=model)\n    assert_allclose(m_idt, m_js)\n\n    J_Bh_idt = kin_dyn.total_momentum_jacobian()\n    J_Bh_js = js.model.total_momentum_jacobian(model=model, data=data)\n    assert_allclose(J_Bh_idt, J_Bh_js)\n\n    h_tot_idt = kin_dyn.total_momentum()\n    h_tot_js = js.model.total_momentum(model=model, data=data)\n    assert_allclose(h_tot_idt, h_tot_js)\n\n    M_locked_idt = kin_dyn.locked_spatial_inertia()\n    M_locked_js = js.model.locked_spatial_inertia(model=model, data=data)\n    assert_allclose(M_locked_idt, M_locked_js)\n\n    J_avg_idt = kin_dyn.average_velocity_jacobian()\n    J_avg_js = js.model.average_velocity_jacobian(model=model, data=data)\n    assert_allclose(J_avg_idt, J_avg_js)\n\n    v_avg_idt = kin_dyn.average_velocity()\n    v_avg_js = js.model.average_velocity(model=model, data=data)\n    assert_allclose(v_avg_idt, v_avg_js)\n\n\ndef test_model_rbda(\n    jaxsim_models_types: js.model.JaxSimModel,\n    prng_key: jax.Array,\n    velocity_representation: VelRepr,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model, key=subkey, velocity_representation=velocity_representation\n    )\n\n    kin_dyn = utils.build_kindyncomputations_from_jaxsim_model(model=model, data=data)\n\n    # =====\n    # Tests\n    # =====\n\n    # Support both fixed-base and floating-base models by slicing the first six 
rows.\n    sl = np.s_[0:] if model.floating_base() else np.s_[6:]\n\n    # Mass matrix\n    M_idt = kin_dyn.mass_matrix()\n    M_js = js.model.free_floating_mass_matrix(model=model, data=data)\n    assert_allclose(M_idt[sl, sl], M_js[sl, sl])\n\n    # Gravity forces\n    g_idt = kin_dyn.gravity_forces()\n    g_js = js.model.free_floating_gravity_forces(model=model, data=data)\n    assert_allclose(g_idt[sl], g_js[sl])\n\n    # Bias forces\n    h_idt = kin_dyn.bias_forces()\n    h_js = js.model.free_floating_bias_forces(model=model, data=data)\n    assert_allclose(h_idt[sl], h_js[sl])\n\n    # Forward kinematics\n    HH_js = data._link_transforms\n    HH_idt = jnp.stack(\n        [kin_dyn.frame_transform(frame_name=name) for name in model.link_names()]\n    )\n    assert_allclose(HH_idt, HH_js)\n\n    # Bias accelerations\n    Jν_js = js.model.link_bias_accelerations(model=model, data=data)\n    Jν_idt = jnp.stack(\n        [kin_dyn.frame_bias_acc(frame_name=name) for name in model.link_names()]\n    )\n    assert_allclose(Jν_idt, Jν_js)\n\n    # Mass matrix inverse via RBDA\n    M_inv_js = js.model.free_floating_mass_matrix_inverse(model=model, data=data)\n    M_inv_idt = jnp.linalg.inv(M_idt)\n    assert_allclose(M_inv_idt[sl, sl], M_inv_js[sl, sl])\n\n\ndef test_model_jacobian(\n    jaxsim_models_types: js.model.JaxSimModel,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    key, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model, key=subkey, velocity_representation=VelRepr.Inertial\n    )\n\n    # =====\n    # Tests\n    # =====\n\n    # Create random references (joint torques and link forces)\n    _, subkey1, subkey2 = jax.random.split(key, num=3)\n    references = js.references.JaxSimModelReferences.build(\n        model=model,\n        joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)),\n        link_forces=jax.random.uniform(subkey2, 
shape=(model.number_of_links(), 6)),\n        data=data,\n        velocity_representation=data.velocity_representation,\n    )\n\n    # Remove the force applied to the base link if the model is fixed-base.\n    if not model.floating_base():\n        references = references.apply_link_forces(\n            forces=jnp.atleast_2d(jnp.zeros(6)),\n            model=model,\n            data=data,\n            link_names=(model.base_link(),),\n            additive=False,\n        )\n\n    # Get the J.T @ f product in inertial-fixed input/output representation.\n    # We use the doubly right-trivialized jacobian with inertial-fixed 6D forces.\n    with (\n        references.switch_velocity_representation(VelRepr.Inertial),\n        data.switch_velocity_representation(VelRepr.Inertial),\n    ):\n\n        f = references.link_forces(model=model, data=data)\n        assert_allclose(f, references._link_forces)\n\n        J = js.model.generalized_free_floating_jacobian(model=model, data=data)\n        JTf_inertial = jnp.einsum(\"l6g,l6->g\", J, f)\n\n    for vel_repr in (VelRepr.Body, VelRepr.Mixed):\n        with references.switch_velocity_representation(vel_repr):\n\n            # Get the jacobian having an inertial-fixed input representation (so that\n            # it computes the same quantity computed above) and an output representation\n            # compatible with the frame in which the external forces are expressed.\n            with data.switch_velocity_representation(VelRepr.Inertial):\n\n                J = js.model.generalized_free_floating_jacobian(\n                    model=model, data=data, output_vel_repr=vel_repr\n                )\n\n            # Get the forces in the tested representation and compute the product\n            # O_J_WL_W.T @ O_f, producing a generalized force in W.\n            # The result can be tested against the one computed before.\n            with data.switch_velocity_representation(vel_repr):\n\n                f = 
references.link_forces(model=model, data=data)\n                JTf_other = jnp.einsum(\"l6g,l6->g\", J, f)\n                assert_allclose(JTf_inertial, JTf_other, err_msg=vel_repr.name)\n\n\ndef test_coriolis_matrix(\n    jaxsim_models_types: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model, key=subkey, velocity_representation=velocity_representation\n    )\n\n    # =====\n    # Tests\n    # =====\n\n    I_ν = data.generalized_velocity\n    C = js.model.free_floating_coriolis_matrix(model=model, data=data)\n\n    h = js.model.free_floating_bias_forces(model=model, data=data)\n    g = js.model.free_floating_gravity_forces(model=model, data=data)\n    Cν = h - g\n\n    assert_allclose(C @ I_ν, Cν)\n\n    # Compute the free-floating mass matrix.\n    # This function will be used to compute the Ṁ with AD.\n    # Given q, computing Ṁ by AD-ing this function should work out-of-the-box with\n    # all velocity representations, that are handled internally when computing M.\n    def M(q) -> jax.Array:\n\n        data_ad = js.data.JaxSimModelData.build(\n            model=model,\n            velocity_representation=data.velocity_representation,\n            base_position=q[:3],\n            base_quaternion=q[3:7],\n            joint_positions=q[7:],\n        )\n\n        M = js.model.free_floating_mass_matrix(model=model, data=data_ad)\n\n        return M\n\n    def compute_q(data: js.data.JaxSimModelData) -> jax.Array:\n\n        q = jnp.hstack(\n            [data.base_position, data.base_orientation, data.joint_positions]\n        )\n\n        return q\n\n    def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:\n\n        with data.switch_velocity_representation(VelRepr.Body):\n            B_ω_WB = data.base_velocity[3:6]\n\n        with 
data.switch_velocity_representation(VelRepr.Mixed):\n            W_ṗ_B = data.base_velocity[0:3]\n\n        W_Q̇_B = jaxsim.math.Quaternion.derivative(\n            quaternion=data.base_orientation,\n            omega=B_ω_WB,\n            omega_in_body_fixed=True,\n            K=0.0,\n        ).squeeze()\n\n        q̇ = jnp.hstack([W_ṗ_B, W_Q̇_B, data.joint_velocities])\n\n        return q̇\n\n    # Compute q and q̇.\n    q = compute_q(data)\n    q̇ = compute_q̇(data)\n\n    # Compute Ṁ with AD.\n    dM_dq = jax.jacfwd(M, argnums=0)(q)\n    Ṁ = jnp.einsum(\"ijq,q->ij\", dM_dq, q̇)\n\n    # We need to zero the blocks projecting joint variables to the base configuration\n    # for fixed-base models.\n    if not model.floating_base():\n        Ṁ = Ṁ.at[0:6, 6:].set(0)\n        Ṁ = Ṁ.at[6:, 0:6].set(0)\n\n    # Ensure that (Ṁ - 2C) is skew symmetric.\n    assert_allclose(Ṁ - C - C.T, 0.0)\n\n\ndef test_model_fd_id_consistency(\n    jaxsim_models_types: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    key, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model, key=subkey, velocity_representation=velocity_representation\n    )\n\n    # =====\n    # Tests\n    # =====\n\n    # Create random references (joint torques and link forces).\n    _, subkey1, subkey2 = jax.random.split(key, num=3)\n    references = js.references.JaxSimModelReferences.build(\n        model=model,\n        joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)),\n        link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)),\n        data=data,\n        velocity_representation=data.velocity_representation,\n    )\n\n    # Remove the force applied to the base link if the model is fixed-base.\n    if not model.floating_base():\n        references = references.apply_link_forces(\n            
forces=jnp.atleast_2d(jnp.zeros(6)),\n            model=model,\n            data=data,\n            link_names=(model.base_link(),),\n            additive=False,\n        )\n\n    # Compute forward dynamics with ABA.\n    v̇_WB_aba, s̈_aba = js.model.forward_dynamics_aba(\n        model=model,\n        data=data,\n        joint_forces=references.joint_force_references(),\n        link_forces=references.link_forces(model=model, data=data),\n    )\n\n    # Compute forward dynamics with CRB.\n    v̇_WB_crb, s̈_crb = js.model.forward_dynamics_crb(\n        model=model,\n        data=data,\n        joint_forces=references.joint_force_references(),\n        link_forces=references.link_forces(model=model, data=data),\n    )\n\n    assert_allclose(s̈_aba, s̈_crb)\n    assert_allclose(v̇_WB_aba, v̇_WB_crb)\n\n    # Compute inverse dynamics with the quantities computed by forward dynamics\n    fB_id, τ_id = js.model.inverse_dynamics(\n        model=model,\n        data=data,\n        joint_accelerations=s̈_aba,\n        base_acceleration=v̇_WB_aba,\n        link_forces=references.link_forces(model=model, data=data),\n    )\n\n    # Check consistency between FD and ID\n    assert_allclose(τ_id, references.joint_force_references(model=model))\n    assert_allclose(fB_id, 0.0)\n\n    if model.floating_base():\n        # If we remove the base 6D force from the inputs, we should find it as output.\n        fB_id, τ_id = js.model.inverse_dynamics(\n            model=model,\n            data=data,\n            joint_accelerations=s̈_aba,\n            base_acceleration=v̇_WB_aba,\n            link_forces=references.link_forces(model=model, data=data)\n            .at[0]\n            .set(jnp.zeros(6)),\n        )\n\n        assert_allclose(τ_id, references.joint_force_references(model=model))\n        assert_allclose(fB_id, references.link_forces(model=model, data=data)[0])\n\n\ndef test_aba_vs_aba_parallel(\n    jaxsim_models_all: js.model.JaxSimModel,\n    velocity_representation: 
VelRepr,\n    prng_key: jax.Array,\n):\n    \"\"\"\n    Verify that the level-parallel ABA produces identical results to\n    the sequential ABA, both at the low-level RBDA and at the high-level\n    model API.\n    \"\"\"\n\n    model = jaxsim_models_all\n\n    key, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model, key=subkey, velocity_representation=velocity_representation\n    )\n\n    # Create random references.\n    _, subkey1, subkey2 = jax.random.split(key, num=3)\n    references = js.references.JaxSimModelReferences.build(\n        model=model,\n        joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)),\n        link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)),\n        data=data,\n        velocity_representation=data.velocity_representation,\n    )\n\n    if not model.floating_base():\n        references = references.apply_link_forces(\n            forces=jnp.atleast_2d(jnp.zeros(6)),\n            model=model,\n            data=data,\n            link_names=(model.base_link(),),\n            additive=False,\n        )\n\n    joint_forces = references.joint_force_references()\n    link_forces = references.link_forces(model=model, data=data)\n\n    v̇_WB_seq, s̈_seq = js.model.forward_dynamics_aba(\n        model=model,\n        data=data,\n        joint_forces=joint_forces,\n        link_forces=link_forces,\n        parallel=False,\n    )\n\n    v̇_WB_par, s̈_par = js.model.forward_dynamics_aba(\n        model=model,\n        data=data,\n        joint_forces=joint_forces,\n        link_forces=link_forces,\n        parallel=True,\n    )\n\n    assert_allclose(v̇_WB_seq, v̇_WB_par, atol=1e-9)\n    assert_allclose(s̈_seq, s̈_par, atol=1e-9)\n\n\ndef test_fk_vs_fk_parallel(\n    jaxsim_models_all: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    prng_key: jax.Array,\n):\n    \"\"\"\n    Verify that the level-parallel FK produces 
identical results to\n    the sequential FK.\n    \"\"\"\n\n    model = jaxsim_models_all\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data = js.data.random_model_data(\n        model=model, key=subkey, velocity_representation=velocity_representation\n    )\n\n    W_H_seq = js.model.forward_kinematics(model=model, data=data, parallel=False)\n    W_H_par = js.model.forward_kinematics(model=model, data=data, parallel=True)\n\n    assert_allclose(W_H_seq, W_H_par, atol=1e-9)\n"
  },
  {
    "path": "tests/test_api_model_hw_parametrization.py",
    "content": "import pathlib\nimport xml.etree.ElementTree as ET\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport pytest\nimport rod\n\nimport jaxsim.api as js\nfrom jaxsim.api.kin_dyn_parameters import (\n    HwLinkMetadata,\n    LinkParametrizableShape,\n    ScalingFactors,\n)\nfrom jaxsim.rbda.contacts import SoftContactsParams\n\nfrom .utils import assert_allclose\n\n\ndef test_update_hw_link_parameters(jaxsim_model_garpez: js.model.JaxSimModel):\n    \"\"\"\n    Test that the hardware parameters of the model are updated correctly.\n    \"\"\"\n\n    model = jaxsim_model_garpez\n\n    # Store initial hardware parameters\n    initial_metadata = model.kin_dyn_parameters.hw_link_metadata\n\n    # Create the scaling factors\n    scaling_parameters = ScalingFactors(\n        dims=jnp.array(\n            [\n                [2.0, 1.5, 1.0],  # Scale x, y, z for link1\n                [1.2, 1.0, 1.0],  # Scale r for link2\n                [1.5, 0.8, 1.0],  # Scale r, l for link3\n                [1.5, 1.0, 0.8],  # Scale x, y, z for link4\n            ]\n        ),\n        density=jnp.ones(4),\n    )\n\n    # Update the model using the scaling factors\n    updated_model = js.model.update_hw_parameters(model, scaling_parameters)\n\n    # Compare updated hardware parameters\n    for link_idx, link_name in enumerate(model.link_names()):\n        updated_metadata = jax.tree.map(\n            lambda x, link_idx=link_idx: x[link_idx],\n            updated_model.kin_dyn_parameters.hw_link_metadata,\n        )\n        initial_metadata_link = jax.tree.map(\n            lambda x, link_idx=link_idx: x[link_idx], initial_metadata\n        )\n\n        # TODO: Compute the 3D scaling vector\n        # scale_vector = HwLinkMetadata._convert_scaling_to_3d_vector(\n        #     initial_metadata_link.shape, scaling_parameters.dims[link_idx]\n        # )\n\n        expected_link_dimensions = (\n            initial_metadata_link.geometry * 
scaling_parameters.dims[link_idx]\n        )\n\n        # Compare shape dimensions\n        assert_allclose(\n            updated_metadata.geometry,\n            expected_link_dimensions,\n            atol=1e-6,\n            err_msg=f\"Mismatch in dimensions for link {link_name}\",\n        )\n\n\n@pytest.mark.parametrize(\n    \"jaxsim_model_garpez_scaled\",\n    [\n        {\n            \"link1_scale\": 4.0,\n            \"link2_scale\": 3.0,\n            \"link3_scale\": 2.0,\n            \"link4_scale\": 1.5,\n        }\n    ],\n    indirect=True,\n)\ndef test_model_scaling_against_rod(\n    jaxsim_model_garpez: js.model.JaxSimModel,\n    jaxsim_model_garpez_scaled: js.model.JaxSimModel,\n):\n    \"\"\"\n    Test that scaling the HW parameters of JaxSim model matches the kin/dyn quantities of a JaxSim model obtained from a pre-scaled rod model.\n    \"\"\"\n\n    # Define scaling parameters\n    scaling_parameters = ScalingFactors(\n        dims=jnp.array(\n            [\n                [4.0, 1.0, 1.0],  # Scale only x-dimension for link1\n                [3.0, 1.0, 1.0],  # Scale only r-dimension for link2\n                [1.0, 2.0, 1.0],  # Scale l dimension for link3\n                [1.5, 1.0, 1.0],  # Scale only x-dimension for link4\n            ]\n        ),\n        density=jnp.ones(4),\n    )\n\n    # Apply scaling to the original JaxSim model\n    updated_model = js.model.update_hw_parameters(\n        jaxsim_model_garpez, scaling_parameters\n    )\n\n    # Compare hardware parameters of the scaled JaxSim model with the pre-scaled JaxSim model\n    scaled_metadata = updated_model.kin_dyn_parameters.hw_link_metadata\n\n    pre_scaled_metadata = jaxsim_model_garpez_scaled.kin_dyn_parameters.hw_link_metadata\n\n    # Compare shape dimensions\n    assert_allclose(scaled_metadata.geometry, pre_scaled_metadata.geometry, atol=1e-6)\n\n    # Compare mass\n    scaled_mass, _ = HwLinkMetadata.compute_mass_and_inertia(scaled_metadata)\n    pre_scaled_mass, _ 
= HwLinkMetadata.compute_mass_and_inertia(pre_scaled_metadata)\n\n    assert_allclose(scaled_mass, pre_scaled_mass, atol=1e-6)\n\n    # Compare inertia tensors\n    _, scaled_inertia = HwLinkMetadata.compute_mass_and_inertia(scaled_metadata)\n    _, pre_scaled_inertia = HwLinkMetadata.compute_mass_and_inertia(pre_scaled_metadata)\n\n    assert_allclose(scaled_inertia, pre_scaled_inertia, atol=1e-6)\n\n    # Compare transformations\n    assert_allclose(scaled_metadata.L_H_G, pre_scaled_metadata.L_H_G, atol=1e-6)\n    assert_allclose(scaled_metadata.L_H_vis, pre_scaled_metadata.L_H_vis, atol=1e-6)\n\n    # Compare collidable points positions\n    assert_allclose(\n        jaxsim_model_garpez_scaled.kin_dyn_parameters.contact_parameters.point,\n        updated_model.kin_dyn_parameters.contact_parameters.point,\n        atol=1e-6,\n    )\n\n\ndef test_update_hw_parameters_vmap(\n    jaxsim_model_garpez: js.model.JaxSimModel,\n):\n    \"\"\"\n    Test that the hardware parameters of the model are updated correctly using vmap\n    to create a set of n updated models.\n    \"\"\"\n\n    model_nominal = jaxsim_model_garpez\n    dofs = model_nominal.dofs()\n\n    # Define a set of scaling factors for n models\n    n = 10  # Number of updated models to create\n    scaling_factors = [\n        ScalingFactors(\n            dims=(scale * jnp.ones((model_nominal.number_of_links(), 3))),\n            density=(scale * jnp.ones(model_nominal.number_of_links())),\n        )\n        for scale in jnp.linspace(2.0, 2.0 + n - 1, n)\n    ]\n\n    # Convert the list of ScalingFactors to a JAX array of pytrees\n    scaling_factors = jax.tree.map(lambda *l: jnp.stack(l), *scaling_factors)\n\n    # Generate a batch of updated models using vmap\n    updated_models = jax.vmap(js.model.update_hw_parameters, in_axes=(None, 0))(\n        model_nominal,\n        scaling_factors,\n    )\n\n    def validate_model(updated_model):\n        assert updated_model is not None\n\n        # Compute forward 
kinematics for the \"link3\" link\n        H_link3 = js.link.transform(\n            model=updated_model,\n            data=js.data.JaxSimModelData.build(model=updated_model),\n            link_index=js.link.name_to_idx(model=updated_model, link_name=\"link3\"),\n        )\n\n        # Compute the mass matrix\n        M = js.model.free_floating_mass_matrix(\n            model=updated_model,\n            data=js.data.JaxSimModelData.build(model=updated_model),\n        )\n\n        assert H_link3 is not None\n        assert H_link3.shape == (4, 4)\n        assert M is not None\n        assert isinstance(M, jnp.ndarray)\n        assert M.shape == (6 + dofs, 6 + dofs)\n\n    # Use vmap to validate all updated models\n    jax.vmap(validate_model)(updated_models)\n\n\n@pytest.mark.parametrize(\n    \"jaxsim_model_garpez_scaled\",\n    [\n        {\n            \"link1_scale\": 4.0,\n            \"link2_scale\": 3.0,\n            \"link3_scale\": 2.0,\n            \"link4_scale\": 1.5,\n        }\n    ],\n    indirect=True,\n)\ndef test_export_updated_model(\n    jaxsim_model_garpez: js.model.JaxSimModel,\n    jaxsim_model_garpez_scaled: js.model.JaxSimModel,\n):\n    \"\"\"\n    Test the export of an updated model using JaxSimModel.export_updated_model.\n    \"\"\"\n\n    model: js.model.JaxSimModel = jaxsim_model_garpez\n\n    # Define scaling parameters\n    scaling_parameters = ScalingFactors(\n        dims=jnp.array(\n            [\n                [4.0, 1.0, 1.0],  # Scale x-dimension for link1\n                [3.0, 1.0, 1.0],  # Scale r-dimension for link2\n                [1.0, 2.0, 1.0],  # Scale l-dimension for link3\n                [1.5, 1.0, 1.0],  # Scale x-dimension for link4\n            ]\n        ),\n        density=jnp.ones(4),\n    )\n    identity_scaling = ScalingFactors(\n        dims=jnp.ones((model.number_of_links(), 3)),\n        density=jnp.ones(model.number_of_links()),\n    )\n\n    def get_link_by_name(model, name):\n        try:\n           
 return next(link for link in model.links() if link.name == name)\n        except StopIteration as err:\n            raise ValueError(\n                f\"Link '{name}' not found. Available links: {[l.name for l in model.links()]}\"\n            ) from err\n\n    def compare_geometries(exported_link, ref_link, label=\"\"):\n        exported_geom = exported_link.visual.geometry.geometry()\n        ref_geom = ref_link.visual.geometry.geometry()\n\n        attrs = [attr for attr in vars(exported_geom) if hasattr(ref_geom, attr)]\n        exported_vals = jnp.array([getattr(exported_geom, attr) for attr in attrs])\n        ref_vals = jnp.array([getattr(ref_geom, attr) for attr in attrs])\n        assert_allclose(\n            exported_vals,\n            ref_vals,\n            err_msg=f\"Geometry mismatch in {label} model.\",\n            atol=1e-6,\n        )\n\n    def compare_mass_and_inertia(exported_link, ref_link, label=\"\"):\n        assert_allclose(\n            exported_link.inertial.mass,\n            ref_link.inertial.mass,\n            atol=1e-4,\n            err_msg=f\"Mass mismatch in {label} model.\",\n        )\n        assert_allclose(\n            exported_link.inertial.inertia.matrix(),\n            ref_link.inertial.inertia.matrix(),\n            atol=1e-4,\n            err_msg=f\"Inertia matrix mismatch in {label} model.\",\n        )\n\n    def compare_collisions(exported_link, ref_link, label=\"\"):\n        geom_types = [\"box\", \"sphere\", \"cylinder\"]\n        for geom_type in geom_types:\n            exp_geom = getattr(exported_link.collision.geometry, geom_type)\n            ref_geom = getattr(ref_link.collision.geometry, geom_type)\n            if ref_geom is not None:\n                if geom_type == \"box\":\n                    assert_allclose(\n                        jnp.array(exp_geom.size), jnp.array(ref_geom.size), atol=1e-6\n                    )\n                elif geom_type == \"sphere\":\n                    
assert_allclose(exp_geom.radius, ref_geom.radius, atol=1e-6)\n                elif geom_type == \"cylinder\":\n                    assert_allclose(exp_geom.radius, ref_geom.radius, atol=1e-6)\n                    assert_allclose(exp_geom.length, ref_geom.length, atol=1e-6)\n                return\n        pytest.skip(\n            f\"Collision geometry type for link {exported_link.name} not supported.\"\n        )\n\n    def validate_model(updated_model, ref_model, label):\n\n        urdf = updated_model.export_updated_model()\n        assert isinstance(urdf, str), f\"{label}: Exported URDF is not a string.\"\n\n        exported_sdf = rod.Sdf.load(urdf, is_urdf=True)\n\n        assert (\n            len(exported_sdf.models()) == 1\n        ), f\"{label}: Exported model does not contain exactly one ROD model.\"\n\n        exported_model = exported_sdf.models()[0]\n\n        for link_name in model.link_names():\n\n            exported_link = get_link_by_name(exported_model, link_name)\n            ref_link = get_link_by_name(ref_model, link_name)\n\n            compare_geometries(exported_link, ref_link, label=label)\n\n            compare_mass_and_inertia(exported_link, ref_link, label=label)\n\n            compare_collisions(exported_link, ref_link, label=label)\n\n    # Test both scaled and identity-scaled updates\n    for scaling, label in (\n        (scaling_parameters, \"SCALED\"),\n        (identity_scaling, \"IDENTITY SCALED\"),\n    ):\n        # Load reference ROD model\n        if label == \"IDENTITY SCALED\":\n            ref_model = rod.Sdf.load(jaxsim_model_garpez.built_from).models()[0]\n        else:\n            ref_model = rod.Sdf.load(jaxsim_model_garpez_scaled.built_from).models()[0]\n\n        updated_model = js.model.update_hw_parameters(model, scaling)\n        validate_model(updated_model, ref_model, label)\n\n\ndef test_hw_parameters_optimization(jaxsim_model_garpez: js.model.JaxSimModel):\n    \"\"\"\n    Test that updating hardware 
parameters allows optimizing the position of a link\n    to match a target value along a specific world axis.\n    \"\"\"\n\n    model = jaxsim_model_garpez\n    data = js.data.JaxSimModelData.build(model=model)\n\n    # Define the target height for the link.\n    target_height = 3.0\n\n    # Get the index of the link to optimize (e.g., \"torso\").\n    link_idx = js.link.name_to_idx(model, link_name=\"link4\")\n\n    # Define the initial hardware parameters (scaling factors).\n    initial_dims = jnp.ones(\n        (model.number_of_links(), 3)\n    )  # Initial dimensions (1.0 for all links).\n    initial_density = jnp.ones(\n        (model.number_of_links(),)\n    )  # Initial density (1.0 for all links).\n    scaling_factors = js.kin_dyn_parameters.ScalingFactors(\n        dims=initial_dims, density=initial_density\n    )\n\n    # Define the loss function.\n    def loss(scaling_factors):\n        # Update the model with the new hardware parameters.\n        updated_model = js.model.update_hw_parameters(\n            model=model, scaling_factors=scaling_factors\n        )\n\n        # Sync the data's cached kinematics (joint transforms, link transforms, …)\n        # with the updated model geometry before running any dynamics.\n        updated_data = data.replace(model=updated_model)\n\n        # Compute forward kinematics for the link.\n        W_H_L = js.model.forward_kinematics(model=updated_model, data=updated_data)[\n            link_idx\n        ]\n\n        # Extract the height (z-axis position) of the link.\n        link4_height = W_H_L[2, 3]  # Assuming z-axis is the third row.\n\n        # Compute the loss as the squared difference from the target height.\n        return (link4_height - target_height) ** 2\n\n    # Compute the gradient of the loss function with respect to the scaling factors.\n    loss_grad = jax.grad(loss)\n\n    # Perform gradient descent.\n    alpha = 0.01  # Learning rate.\n    num_iterations = 1000  # Number of gradient descent 
steps.\n    for _ in range(num_iterations):\n        # Compute the gradient.\n        grad_scaling_factors = loss_grad(scaling_factors)\n\n        # Update the scaling factors.\n        scaling_factors = js.kin_dyn_parameters.ScalingFactors(\n            dims=scaling_factors.dims - alpha * grad_scaling_factors.dims,\n            density=scaling_factors.density - alpha * grad_scaling_factors.density,\n        )\n\n        # Compute the current loss value.\n        current_loss = loss(scaling_factors)\n\n        # Optionally, print the progress.\n        if _ % 100 == 0:\n            print(f\"Iteration {_}: Loss = {current_loss}\")\n\n    # Assert that the final loss is close to zero.\n    assert current_loss < 1e-3, \"Optimization did not converge to the target height.\"\n\n\ndef test_hw_parameters_collision_scaling(\n    jaxsim_model_box: js.model.JaxSimModel, prng_key: jax.Array\n):\n    \"\"\"\n    Test that the collision elements of the model are updated correctly during the scaling of the model hw parameters.\n    \"\"\"\n\n    _, subkey = jax.random.split(prng_key, num=2)\n\n    # TODO: the jaxsim_model_box has an additional frame, which is handled wrongly\n    # during the export of the updated model. 
For this reason, we recreate the model\n    # from scratch here.\n    del jaxsim_model_box\n\n    import rod.builder.primitives\n\n    # Create on-the-fly a ROD model of a box.\n    rod_model = (\n        rod.builder.primitives.BoxBuilder(x=0.3, y=0.2, z=0.1, mass=1.0, name=\"box\")\n        .build_model()\n        .add_link(name=\"box_link\")\n        .add_inertial()\n        .add_visual()\n        .add_collision()\n        .build()\n    )\n\n    model = js.model.JaxSimModel.build_from_model_description(\n        model_description=rod_model\n    )\n\n    # Define the scaling factor for the model\n    scaling_factor = 5.0\n\n    # Recompute K and D, since the mass is scaled by scaling_factor^3\n    # and the expected static compression of the terrain is approximately\n    # proportional to mass/K and divided by the 4 contact points.\n    K = model.contact_params.K * (scaling_factor**2)\n\n    # Strongly overdamped, to avoid oscillations due to the high mass\n    # and the low penetration allowed.\n    D = 8 * jnp.sqrt(K)\n\n    with model.editable(validate=False) as model:\n        model.contact_params = SoftContactsParams(K=K, D=D)\n\n    # Define the nominal radius of the sphere\n    nominal_height = model.kin_dyn_parameters.hw_link_metadata.geometry[0, 2]\n\n    # Define scaling parameters\n    scaling_parameters = ScalingFactors(\n        dims=jnp.ones((model.number_of_links(), 3)) * scaling_factor,\n        density=jnp.array([1.0]),\n    )\n\n    # Update the model with the scaling parameters\n    updated_model = js.model.update_hw_parameters(model, scaling_parameters)\n\n    # Compute the expected height (nominal radius * scaling factor)\n    expected_height = nominal_height * scaling_factor / 2\n\n    # Simulate the box falling under gravity\n    data = js.data.JaxSimModelData.build(\n        model=updated_model,\n        # Set the initial position of the box's base to be slightly above the ground\n        # to allow it to settle at the expected height after 
scaling.\n        # The base position is set to the nominal height of the box scaled by the scaling factor,\n        # plus a small offset to avoid immediate collision with the ground.\n        # This ensures that the box has enough space to fall and settle at the expected height.\n        base_position=jnp.array(\n            [\n                *jax.random.uniform(subkey, shape=(2,)),\n                expected_height + 0.05,\n            ]\n        ),\n    )\n\n    num_steps = 1000  # Number of simulation steps\n\n    for _ in range(num_steps):\n        data = js.model.step(\n            model=updated_model,\n            data=data,\n        )\n\n    # Get the final height of the box's base\n    updated_base_height = data.base_position[2]\n\n    # Assert that the box settles at the expected height\n    assert jnp.isclose(\n        updated_base_height, expected_height, atol=1e-3\n    ), f\"model base height mismatch: expected {expected_height}, got {updated_base_height}\"\n\n\ndef test_unsupported_link_cases():\n    \"\"\"\n    Test that unsupported link cases are handled correctly.\n    \"\"\"\n    import rod.builder.primitives\n\n    from jaxsim.api.kin_dyn_parameters import LinkParametrizableShape\n\n    # Test unsupported (no visual)\n    no_visual_model = js.model.JaxSimModel.build_from_model_description(\n        rod.builder.primitives.BoxBuilder(x=1, y=1, z=1, mass=1, name=\"no_vis_box\")\n        .build_model()\n        .add_link(name=\"no_visual_link\")\n        .add_inertial()\n        .build()  # No .add_visual()\n    )\n    no_visual_metadata = no_visual_model.kin_dyn_parameters.hw_link_metadata\n    empty_metadata = HwLinkMetadata.empty()\n    comparison = jax.tree.map(\n        jnp.allclose,\n        no_visual_metadata,\n        empty_metadata,\n    )\n    assert jax.tree.reduce(\n        lambda acc, value: acc and bool(value), comparison, True\n    ), \"No links should be supported.\"\n\n    # Create a simple multi-link URDF and add collision to 
ensure links are kept\n    multi_link_urdf = \"\"\"\n        <?xml version=\"1.0\"?>\n        <robot name=\"two_link_test\">\n\n        <!-- Link 1: Supported (with box visual) -->\n        <link name=\"supported_link\">\n            <inertial>\n            <mass value=\"1.0\"/>\n            <inertia ixx=\"1\" iyy=\"1\" izz=\"1\" ixy=\"0\" ixz=\"0\" iyz=\"0\"/>\n            </inertial>\n            <visual>\n            <geometry>\n                <box size=\"1.0 1.0 1.0\"/>\n            </geometry>\n            </visual>\n            <collision>\n            <geometry>\n                <box size=\"1.0 1.0 1.0\"/>\n            </geometry>\n            </collision>\n        </link>\n\n        <!-- Link 2: Unsupported (no visual but has collision) -->\n        <link name=\"unsupported_link\">\n            <inertial>\n            <mass value=\"1.0\"/>\n            <inertia ixx=\"1\" iyy=\"1\" izz=\"1\" ixy=\"0\" ixz=\"0\" iyz=\"0\"/>\n            </inertial>\n            <!-- No visual element - this makes it unsupported -->\n            <collision>\n            <geometry>\n                <box size=\"0.5 0.5 0.5\"/>\n            </geometry>\n            </collision>\n        </link>\n\n        <!-- Link 3: Two visuals (first should be picked) -->\n        <link name=\"double_visual_link\">\n            <inertial>\n            <mass value=\"1.0\"/>\n            <inertia ixx=\"1\" iyy=\"1\" izz=\"1\" ixy=\"0\" ixz=\"0\" iyz=\"0\"/>\n            </inertial>\n            <visual name=\"primary_sphere\">\n            <geometry>\n                <sphere radius=\"0.4\"/>\n            </geometry>\n            </visual>\n            <visual name=\"secondary_box\">\n            <geometry>\n                <box size=\"0.8 0.2 0.2\"/>\n            </geometry>\n            </visual>\n            <collision>\n            <geometry>\n                <sphere radius=\"0.4\"/>\n            </geometry>\n            </collision>\n        </link>\n\n        <!-- Joint connecting the 
links -->\n        <joint name=\"connecting_joint\" type=\"revolute\">\n            <origin xyz=\"0.0 0.0 0.0\" rpy=\"0.0 0.0 0.0\"/>\n            <parent link=\"supported_link\"/>\n            <child link=\"unsupported_link\"/>\n            <axis xyz=\"1 0 0\"/>\n            <limit effort=\"3.4028235e+38\" velocity=\"3.4028235e+38\"/>\n        </joint>\n\n        <!-- Joint for double visual link -->\n        <joint name=\"double_visual_joint\" type=\"revolute\">\n            <origin xyz=\"0.1 0.0 0.0\" rpy=\"0.0 0.0 0.0\"/>\n            <parent link=\"unsupported_link\"/>\n            <child link=\"double_visual_link\"/>\n            <axis xyz=\"0 1 0\"/>\n            <limit effort=\"3.4028235e+38\" velocity=\"3.4028235e+38\"/>\n        </joint>\n\n        </robot>\n    \"\"\"\n\n    # Build JaxSim model from the URDF\n    multi_link_model = js.model.JaxSimModel.build_from_model_description(\n        multi_link_urdf, is_urdf=True\n    )\n    multi_link_metadata = multi_link_model.kin_dyn_parameters.hw_link_metadata\n\n    # Verify array consistency for the model\n    num_links = multi_link_model.number_of_links()\n    assert num_links == 3, f\"Expected 3 links in the URDF model, got {num_links}\"\n    assert (\n        len(multi_link_metadata.link_shape)\n        == len(multi_link_metadata.geometry)\n        == len(multi_link_metadata.density)\n        == num_links\n    )\n\n    # Count verification in single model\n    supported_count = sum(\n        1\n        for s in multi_link_metadata.link_shape\n        if s != LinkParametrizableShape.Unsupported\n    )\n    unsupported_count = sum(\n        1\n        for s in multi_link_metadata.link_shape\n        if s == LinkParametrizableShape.Unsupported\n    )\n\n    assert (\n        supported_count == 2\n    ), f\"Expected 2 supported links in single model, got {supported_count}\"\n    assert (\n        unsupported_count == 1\n    ), f\"Expected 1 unsupported link in single model, got {unsupported_count}\"\n\n    
# Ensure shapes match expectations by name\n    link_indices = {name: idx for idx, name in enumerate(multi_link_model.link_names())}\n\n    assert (\n        multi_link_metadata.link_shape[link_indices[\"supported_link\"]]\n        == LinkParametrizableShape.Box\n    ), \"Supported link should remain a box\"\n    assert (\n        multi_link_metadata.link_shape[link_indices[\"unsupported_link\"]]\n        == LinkParametrizableShape.Unsupported\n    ), \"Unsupported link should remain unsupported\"\n\n    double_visual_idx = link_indices[\"double_visual_link\"]\n    assert (\n        multi_link_metadata.link_shape[double_visual_idx]\n        == LinkParametrizableShape.Sphere\n    ), \"Double visual link should pick the first (sphere) visual\"\n    assert_allclose(\n        multi_link_metadata.geometry[double_visual_idx, 0],\n        0.4,\n        err_msg=\"Sphere radius must match the first visual\",\n    )\n\n    # Test selective parametrization: only 'supported_link' and 'double_visual_link' should be parametrized\n    selective_model = js.model.JaxSimModel.build_from_model_description(\n        multi_link_urdf, is_urdf=True, parametrized_links=(\"double_visual_link\")\n    )\n    selective_metadata = selective_model.kin_dyn_parameters.hw_link_metadata\n\n    # Check that only the selected links are parametrized\n    link_indices = {name: idx for idx, name in enumerate(selective_model.link_names())}\n    assert (\n        selective_metadata.link_shape[link_indices[\"supported_link\"]]\n        == LinkParametrizableShape.Unsupported\n    ), \"Selected supported_link should be parametrized as Box\"\n    assert (\n        selective_metadata.link_shape[link_indices[\"double_visual_link\"]]\n        == LinkParametrizableShape.Sphere\n    ), \"Selected double_visual_link should be parametrized as Sphere\"\n    assert (\n        selective_metadata.link_shape[link_indices[\"unsupported_link\"]]\n        == LinkParametrizableShape.Unsupported\n    ), \"Non-selected 
unsupported_link should be marked as Unsupported\"\n\n\ndef test_export_continuous_joint_handling():\n    \"\"\"\n    Test that continuous joints are correctly exported with type=\"continuous\"\n    and without position limits, while preserving effort and velocity limits.\n    \"\"\"\n\n    # Load cartpole model which has a continuous joint (pivot)\n    cartpole_path = (\n        pathlib.Path(__file__).parent.parent / \"examples\" / \"assets\" / \"cartpole.urdf\"\n    )\n    model = js.model.JaxSimModel.build_from_model_description(cartpole_path)\n\n    # Define some simple scaling parameters (identity scaling)\n    scaling_parameters = ScalingFactors(\n        dims=jnp.ones((model.number_of_links(), 3)),\n        density=jnp.ones(model.number_of_links()),\n    )\n\n    # Update the model with scaling parameters\n    updated_model = js.model.update_hw_parameters(model, scaling_parameters)\n\n    # Export the updated model\n    exported_urdf = updated_model.export_updated_model()\n\n    # Parse the URDF XML directly (not through rod, which would convert continuous back to revolute)\n    root = ET.fromstring(exported_urdf)\n\n    # Find the pivot joint (continuous joint)\n    pivot_joint = None\n    for joint_elem in root.findall(\".//joint\"):\n        if joint_elem.get(\"name\") == \"pivot\":\n            pivot_joint = joint_elem\n            break\n\n    assert pivot_joint is not None, \"pivot joint should exist in exported model\"\n\n    # Verify that the joint type is \"continuous\"\n    assert (\n        pivot_joint.get(\"type\") == \"continuous\"\n    ), f\"pivot joint should have type='continuous', got '{pivot_joint.get('type')}'\"\n\n    # Verify that position limits are not present for continuous joints\n    limit_elem = pivot_joint.find(\"limit\")\n    assert limit_elem is not None, \"pivot joint should have limits element\"\n\n    assert (\n        limit_elem.get(\"lower\") is None\n    ), f\"continuous joint should not have lower position limit, got 
{limit_elem.get('lower')}\"\n    assert (\n        limit_elem.get(\"upper\") is None\n    ), f\"continuous joint should not have upper position limit, got {limit_elem.get('upper')}\"\n\n    # Verify that effort and velocity limits are preserved\n    assert (\n        limit_elem.get(\"effort\") is not None\n    ), \"continuous joint should preserve effort limit\"\n    assert (\n        limit_elem.get(\"velocity\") is not None\n    ), \"continuous joint should preserve velocity limit\"\n\n    # Verify that the linear joint (prismatic) is NOT changed to continuous\n    linear_joint = None\n    for joint_elem in root.findall(\".//joint\"):\n        if joint_elem.get(\"name\") == \"linear\":\n            linear_joint = joint_elem\n            break\n\n    assert linear_joint is not None, \"linear joint should exist in exported model\"\n    assert (\n        linear_joint.get(\"type\") == \"prismatic\"\n    ), f\"linear joint should remain prismatic, got '{linear_joint.get('type')}'\"\n\n    # Prismatic joint should keep its limits\n    linear_limit = linear_joint.find(\"limit\")\n    assert linear_limit is not None, \"prismatic joint should have limits\"\n    assert (\n        linear_limit.get(\"lower\") is not None\n    ), \"prismatic joint should have lower limit\"\n    assert (\n        linear_limit.get(\"upper\") is not None\n    ), \"prismatic joint should have upper limit\"\n\n\ndef test_export_model_with_missing_collision(\n    jaxsim_model_missing_collision: js.model.JaxSimModel,\n):\n    \"\"\"\n    Test that export_updated_model() works correctly when a link has a visual\n    but is missing a collision element.\n\n    This validates the skip logic that handles None collision elements.\n    \"\"\"\n\n    model = jaxsim_model_missing_collision\n\n    # Define scaling parameters to modify the model\n    scaling_parameters = ScalingFactors(\n        dims=jnp.array(\n            [\n                [1.5, 1.0, 1.0],  # Scale x-dimension for link1 (has collision)\n     
           [2.0, 1.0, 1.0],  # Scale radius for link2 (missing collision)\n            ]\n        ),\n        density=jnp.ones(2),\n    )\n\n    # Update the model with scaling parameters\n    updated_model = js.model.update_hw_parameters(model, scaling_parameters)\n\n    # Export the updated model - this should NOT fail even though link2 is missing collision\n    exported_urdf = updated_model.export_updated_model()\n\n    # Verify basic structure of exported URDF\n    assert isinstance(exported_urdf, str), \"Exported URDF should be a string\"\n    assert len(exported_urdf) > 0, \"Exported URDF should not be empty\"\n\n    # Parse the exported URDF to verify it's valid XML\n    root = ET.fromstring(exported_urdf)\n    assert root.tag == \"robot\", \"Root element should be 'robot'\"\n\n    # Find both links in the exported model\n    links = {link.get(\"name\"): link for link in root.findall(\".//link\")}\n    assert \"link1\" in links, \"link1 should exist in exported model\"\n    assert \"link2\" in links, \"link2 should exist in exported model\"\n\n    # Verify link1 has both visual and collision\n    link1 = links[\"link1\"]\n    link1_visual = link1.find(\"visual\")\n    link1_collision = link1.find(\"collision\")\n    assert link1_visual is not None, \"link1 should have a visual element\"\n    assert link1_collision is not None, \"link1 should have a collision element\"\n\n    # Verify link1's geometry was updated\n    link1_visual_box = link1_visual.find(\".//box\")\n    assert link1_visual_box is not None, \"link1 visual should have box geometry\"\n    link1_size = [float(x) for x in link1_visual_box.get(\"size\").split()]\n    # First dimension should be scaled by 1.5\n    assert_allclose(\n        link1_size[0],\n        0.3 * 1.5,\n        atol=1e-6,\n        err_msg=\"link1 x-dimension should be scaled\",\n    )\n\n    # Verify link2 has visual but no collision\n    link2 = links[\"link2\"]\n    link2_visual = link2.find(\"visual\")\n    link2_collision 
= link2.find(\"collision\")\n    assert link2_visual is not None, \"link2 should have a visual element\"\n    assert link2_collision is None, \"link2 should NOT have a collision element\"\n\n    # Verify link2's visual geometry was updated despite missing collision\n    link2_visual_sphere = link2_visual.find(\".//sphere\")\n    assert link2_visual_sphere is not None, \"link2 visual should have sphere geometry\"\n    link2_radius = float(link2_visual_sphere.get(\"radius\"))\n    # Radius should be scaled by 2.0\n    assert_allclose(\n        link2_radius,\n        0.1 * 2.0,\n        atol=1e-6,\n        err_msg=\"link2 radius should be scaled despite missing collision\",\n    )\n\n    # Load the exported model to verify it can be parsed correctly\n    exported_sdf = rod.Sdf.load(exported_urdf, is_urdf=True)\n    assert (\n        len(exported_sdf.models()) == 1\n    ), \"Exported model should contain exactly one ROD model\"\n\n    exported_model = exported_sdf.models()[0]\n    assert exported_model.name == model.name(), \"Exported model name should match\"\n\n    # Verify we can build a JaxSim model from the exported URDF\n    _ = js.model.JaxSimModel.build_from_model_description(\n        model_description=exported_urdf, is_urdf=True\n    )\n\n\ndef test_export_mesh_scaling_preserves_nonzero_visual_and_joint_origins(\n    tmp_path: pathlib.Path,\n):\n    \"\"\"\n    Regression test for mesh export:\n    non-identity scaling must preserve non-zero visual/joint origins in the URDF.\n    \"\"\"\n\n    mesh_file = pathlib.Path(__file__).parent / \"assets\" / \"cube.stl\"\n    if not mesh_file.exists():\n        pytest.skip(f\"Test mesh file not found: {mesh_file}\")\n\n    urdf_path = tmp_path / \"mesh_origin_regression.urdf\"\n    urdf_path.write_text(\n        f\"\"\"<?xml version=\"1.0\"?>\n<robot name=\"mesh_origin_regression\">\n  <link name=\"base_link\">\n    <inertial>\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\n      <mass value=\"1.0\"/>\n      <inertia 
ixx=\"0.01\" ixy=\"0\" ixz=\"0\" iyy=\"0.01\" iyz=\"0\" izz=\"0.01\"/>\n    </inertial>\n    <visual>\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\n      <geometry>\n        <box size=\"0.1 0.1 0.1\"/>\n      </geometry>\n    </visual>\n  </link>\n\n  <link name=\"mesh_link\">\n    <inertial>\n      <origin xyz=\"0.01 -0.01 0.02\" rpy=\"0 0 0\"/>\n      <mass value=\"0.01\"/>\n      <inertia ixx=\"1e-6\" ixy=\"0\" ixz=\"0\" iyy=\"1e-6\" iyz=\"0\" izz=\"1e-6\"/>\n    </inertial>\n    <visual>\n      <origin xyz=\"0.03 0.01 -0.02\" rpy=\"0 0 0\"/>\n      <geometry>\n        <mesh filename=\"{mesh_file.as_posix()}\" scale=\"0.001 0.001 0.001\"/>\n      </geometry>\n    </visual>\n    <collision>\n      <origin xyz=\"0.03 0.01 -0.02\" rpy=\"0 0 0\"/>\n      <geometry>\n        <mesh filename=\"{mesh_file.as_posix()}\" scale=\"0.001 0.001 0.001\"/>\n      </geometry>\n    </collision>\n  </link>\n\n  <joint name=\"base_to_mesh\" type=\"revolute\">\n    <parent link=\"base_link\"/>\n    <child link=\"mesh_link\"/>\n    <origin xyz=\"0.4 -0.1 0.2\" rpy=\"0 0 0\"/>\n    <axis xyz=\"0 0 1\"/>\n    <limit lower=\"-1.57\" upper=\"1.57\" effort=\"10\" velocity=\"1\"/>\n  </joint>\n</robot>\n\"\"\",\n        encoding=\"utf-8\",\n    )\n\n    model = js.model.JaxSimModel.build_from_model_description(\n        model_description=urdf_path,\n        is_urdf=True,\n        parametrized_links=(\"mesh_link\",),\n    )\n\n    mesh_link_idx = js.link.name_to_idx(model=model, link_name=\"mesh_link\")\n    dims = jnp.ones((model.number_of_links(), 3))\n    dims = dims.at[mesh_link_idx].set(jnp.array([1.7, 0.8, 1.3]))\n    scaling = ScalingFactors(dims=dims, density=jnp.ones(model.number_of_links()))\n\n    updated_model = js.model.update_hw_parameters(model=model, scaling_factors=scaling)\n    exported_urdf = updated_model.export_updated_model()\n    root = ET.fromstring(exported_urdf)\n\n    visual_origin = root.find(\".//link[@name='mesh_link']/visual/origin\")\n    assert visual_origin 
is not None, \"Mesh visual origin must exist in exported URDF\"\n    visual_xyz = np.array([float(v) for v in visual_origin.get(\"xyz\").split()])\n\n    expected_visual_xyz = np.array(\n        updated_model.kin_dyn_parameters.hw_link_metadata.L_H_vis[mesh_link_idx][:3, 3]\n    )\n    assert_allclose(\n        visual_xyz,\n        expected_visual_xyz,\n        atol=1e-8,\n        err_msg=\"Exported mesh visual origin does not match updated metadata\",\n    )\n    assert not np.allclose(visual_xyz, np.zeros(3), atol=1e-12)\n\n    joint_origin = root.find(\".//joint[@name='base_to_mesh']/origin\")\n    assert joint_origin is not None, \"Joint origin must exist in exported URDF\"\n    joint_xyz = np.array([float(v) for v in joint_origin.get(\"xyz\").split()])\n\n    joint_idx = js.joint.name_to_idx(model=updated_model, joint_name=\"base_to_mesh\")\n    expected_joint_xyz = np.array(\n        updated_model.kin_dyn_parameters.joint_model.λ_H_pre[joint_idx + 1][:3, 3]\n    )\n    assert_allclose(\n        joint_xyz,\n        expected_joint_xyz,\n        atol=1e-8,\n        err_msg=\"Exported joint origin does not match updated joint transform\",\n    )\n    assert not np.allclose(joint_xyz, np.zeros(3), atol=1e-12)\n\n    reimported_jaxsim_model = js.model.JaxSimModel.build_from_model_description(\n        model_description=exported_urdf, is_urdf=True\n    )\n    assert (\n        reimported_jaxsim_model is not None\n    ), \"Should be able to build model from exported URDF\"\n    assert (\n        reimported_jaxsim_model.number_of_links() == model.number_of_links()\n    ), \"Reimported model should have same number of links\"\n\n\n# =============================================================================\n# Mesh Scaling Tests\n# =============================================================================\n\n\ndef test_mesh_shape_enum():\n    \"\"\"Test that the Mesh shape type is available in the enum.\"\"\"\n    assert hasattr(LinkParametrizableShape, 
\"Mesh\")\n    assert LinkParametrizableShape.Mesh == 3\n\n\ndef test_mixed_shapes_metadata():\n    \"\"\"Test loading and metadata verification for mixed primitive and mesh shapes.\"\"\"\n    test_urdf = pathlib.Path(__file__).parent / \"assets\" / \"mixed_shapes_robot.urdf\"\n\n    if not test_urdf.exists():\n        pytest.skip(f\"Test URDF not found: {test_urdf}\")\n\n    mesh_file = pathlib.Path(__file__).parent / \"assets\" / \"cube.stl\"\n    if not mesh_file.exists():\n        pytest.skip(f\"Test mesh not found: {mesh_file}\")\n\n    # Load model with all link types parametrized\n    model = js.model.JaxSimModel.build_from_model_description(\n        model_description=test_urdf,\n        is_urdf=True,\n        parametrized_links=(\"box_link\", \"cylinder_link\", \"mesh_link\", \"sphere_link\"),\n    )\n\n    assert model.name() == \"mixed_shapes_robot\"\n    assert model.number_of_links() == 4\n\n    hw_meta = model.kin_dyn_parameters.hw_link_metadata\n\n    # Verify all 4 links are parametrized with correct shape types\n    assert len(hw_meta.link_shape) == 4\n    assert hw_meta.link_shape[0] == LinkParametrizableShape.Box\n    assert hw_meta.link_shape[1] == LinkParametrizableShape.Cylinder\n    assert hw_meta.link_shape[2] == LinkParametrizableShape.Mesh\n    assert hw_meta.link_shape[3] == LinkParametrizableShape.Sphere\n\n    # Verify mesh data exists only for mesh link\n    assert hw_meta.mesh_vertices is not None\n    assert hw_meta.mesh_vertices[0] is None  # box\n    assert hw_meta.mesh_vertices[1] is None  # cylinder\n    assert hw_meta.mesh_vertices[2] is not None  # mesh\n    assert hw_meta.mesh_vertices[3] is None  # sphere\n    assert hw_meta.mesh_faces is not None\n    assert hw_meta.mesh_faces[2] is not None  # mesh link has faces\n\n\ndef test_mixed_shapes_scaling():\n    \"\"\"Test uniform and non-uniform scaling with mixed primitive and mesh shapes.\"\"\"\n    test_urdf = pathlib.Path(__file__).parent / \"assets\" / 
\"mixed_shapes_robot.urdf\"\n\n    if not test_urdf.exists():\n        pytest.skip(f\"Test URDF not found: {test_urdf}\")\n\n    mesh_file = pathlib.Path(__file__).parent / \"assets\" / \"cube.stl\"\n    if not mesh_file.exists():\n        pytest.skip(f\"Test mesh not found: {mesh_file}\")\n\n    model = js.model.JaxSimModel.build_from_model_description(\n        model_description=test_urdf,\n        is_urdf=True,\n        parametrized_links=(\"box_link\", \"cylinder_link\", \"mesh_link\", \"sphere_link\"),\n    )\n\n    hw_meta = model.kin_dyn_parameters.hw_link_metadata\n    if len(hw_meta.link_shape) == 0:\n        pytest.skip(\"Hardware parametrization not supported\")\n\n    # Get original masses\n    masses_orig = {}\n    for i in range(model.number_of_links()):\n        link_name = js.link.idx_to_name(model=model, link_index=i)\n        masses_orig[link_name] = float(model.kin_dyn_parameters.link_parameters.mass[i])\n\n    # Test uniform scaling (2x), so all links should scaled by 8x\n    uniform_scaling = ScalingFactors(\n        dims=jnp.ones((4, 3)) * 2.0,\n        density=jnp.ones(4),\n    )\n    scaled_uniform = js.model.update_hw_parameters(model, uniform_scaling)\n\n    for i in range(scaled_uniform.number_of_links()):\n        link_name = js.link.idx_to_name(model=scaled_uniform, link_index=i)\n        mass_scaled = float(scaled_uniform.kin_dyn_parameters.link_parameters.mass[i])\n        ratio = mass_scaled / masses_orig[link_name]\n        assert jnp.allclose(\n            ratio, 8.0, rtol=0.1\n        ), f\"Uniform scaling: {link_name} expected 8x, got {ratio:.2f}x\"\n\n    # Test different scaling factors per link\n    different_scaling = ScalingFactors(\n        dims=jnp.array(\n            [\n                [2.0, 2.0, 2.0],  # box: 8x\n                [3.0, 3.0, 3.0],  # cylinder: 27x\n                [1.5, 1.5, 1.5],  # mesh: 3.375x\n                [2.5, 2.5, 2.5],  # sphere: 15.625x\n            ]\n        ),\n        density=jnp.ones(4),\n 
   )\n    scaled_different = js.model.update_hw_parameters(model, different_scaling)\n\n    expected_ratios = {\n        \"box_link\": 8.0,\n        \"cylinder_link\": 27.0,\n        \"mesh_link\": 3.375,\n        \"sphere_link\": 15.625,\n    }\n\n    for i in range(scaled_different.number_of_links()):\n        link_name = js.link.idx_to_name(model=scaled_different, link_index=i)\n        mass_scaled = float(scaled_different.kin_dyn_parameters.link_parameters.mass[i])\n        ratio = mass_scaled / masses_orig[link_name]\n        expected = expected_ratios[link_name]\n        assert jnp.allclose(\n            ratio, expected, rtol=0.1\n        ), f\"Different scaling: {link_name} expected {expected}x, got {ratio:.2f}x\"\n"
  },
  {
    "path": "tests/test_automatic_differentiation.py",
    "content": "import os\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom jax.test_util import check_grads\n\nimport jaxsim.api as js\nimport jaxsim.math\nimport jaxsim.rbda\nimport jaxsim.typing as jtp\nfrom jaxsim import VelRepr\nfrom jaxsim.rbda.contacts import SoftContacts, SoftContactsParams\n\nfrom .utils import assert_allclose\n\n# All JaxSim algorithms, excluding the variable-step integrators, should support\n# being automatically differentiated until second order, both in FWD and REV modes.\n# However, checking the second-order derivatives is particularly slow and makes\n# CI tests take too long. Therefore, we only check first-order derivatives.\nAD_ORDER = os.environ.get(\"JAXSIM_TEST_AD_ORDER\", 1)\n\n# Define the step size used to compute finite differences depending on the\n# floating point resolution.\nε = os.environ.get(\n    \"JAXSIM_TEST_FD_STEP_SIZE\",\n    jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3),\n)\n\n\ndef get_random_data_and_references(\n    model: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    key: jax.Array,\n) -> tuple[js.data.JaxSimModelData, js.references.JaxSimModelReferences]:\n\n    key, subkey = jax.random.split(key, num=2)\n\n    data = js.data.random_model_data(\n        model=model, key=subkey, velocity_representation=velocity_representation\n    )\n\n    _, subkey1, subkey2 = jax.random.split(key, num=3)\n\n    references = js.references.JaxSimModelReferences.build(\n        model=model,\n        joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)),\n        link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)),\n        data=data,\n        velocity_representation=velocity_representation,\n    )\n\n    # Remove the force applied to the base link if the model is fixed-base.\n    if not model.floating_base():\n        references = references.apply_link_forces(\n            forces=jnp.atleast_2d(jnp.zeros(6)),\n            model=model,\n       
     data=data,\n            link_names=(model.base_link(),),\n            additive=False,\n        )\n\n    return data, references\n\n\ndef test_ad_aba(\n    jaxsim_models_types: js.model.JaxSimModel,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data, references = get_random_data_and_references(\n        model=model, velocity_representation=VelRepr.Inertial, key=subkey\n    )\n\n    # Get the standard gravity constant.\n    g = jaxsim.math.STANDARD_GRAVITY\n\n    # State in VelRepr.Inertial representation.\n    W_p_B = data.base_position\n    W_Q_B = data.base_orientation\n    s = data.joint_positions\n    W_v_WB = data.base_velocity\n    ṡ = data.joint_velocities\n\n    # Inputs.\n    W_f_L = references.link_forces(model=model)\n    τ = references.joint_force_references(model=model)\n\n    # ====\n    # Test\n    # ====\n\n    # Get a closure exposing only the parameters to be differentiated.\n    def aba(W_p_B, W_Q_B, s, W_v_WB, ṡ, τ, W_f_L, g):\n        import jaxlie\n\n        W_H_B = jaxlie.SE3.from_rotation_and_translation(\n            rotation=jaxlie.SO3(wxyz=W_Q_B / jnp.linalg.norm(W_Q_B)),\n            translation=W_p_B,\n        ).as_matrix()\n        joint_transforms = model.kin_dyn_parameters.joint_transforms(\n            joint_positions=s, base_transform=W_H_B\n        )\n        return jaxsim.rbda.aba(\n            model=model,\n            base_position=W_p_B,\n            base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B),\n            joint_positions=s,\n            base_linear_velocity=W_v_WB[0:3],\n            base_angular_velocity=W_v_WB[3:6],\n            joint_velocities=ṡ,\n            joint_transforms=joint_transforms,\n            joint_forces=τ,\n            link_forces=W_f_L,\n            standard_gravity=g,\n        )\n\n    # Check derivatives against finite differences.\n    check_grads(\n        f=aba,\n        args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, τ, W_f_L, 
g),\n        order=AD_ORDER,\n        modes=[\"rev\", \"fwd\"],\n        eps=ε,\n    )\n\n\ndef test_ad_aba_parallel(\n    jaxsim_models_types: js.model.JaxSimModel,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data, references = get_random_data_and_references(\n        model=model, velocity_representation=VelRepr.Inertial, key=subkey\n    )\n\n    g = jaxsim.math.STANDARD_GRAVITY\n\n    W_p_B = data.base_position\n    W_Q_B = data.base_orientation\n    s = data.joint_positions\n    W_v_WB = data.base_velocity\n    ṡ = data.joint_velocities\n    i_X_λi = data._joint_transforms\n\n    W_f_L = references.link_forces(model=model)\n    τ = references.joint_force_references(model=model)\n\n    # Verify parallel ABA matches sequential ABA.\n    result_seq = jaxsim.rbda.aba(\n        model=model,\n        base_position=W_p_B,\n        base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B),\n        joint_positions=s,\n        base_linear_velocity=W_v_WB[0:3],\n        base_angular_velocity=W_v_WB[3:6],\n        joint_velocities=ṡ,\n        joint_transforms=i_X_λi,\n        joint_forces=τ,\n        link_forces=W_f_L,\n        standard_gravity=g,\n    )\n\n    result_par = jaxsim.rbda.aba_parallel(\n        model=model,\n        base_position=W_p_B,\n        base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B),\n        joint_positions=s,\n        base_linear_velocity=W_v_WB[0:3],\n        base_angular_velocity=W_v_WB[3:6],\n        joint_velocities=ṡ,\n        joint_forces=τ,\n        joint_transforms=i_X_λi,\n        link_forces=W_f_L,\n        standard_gravity=g,\n    )\n\n    assert_allclose(result_seq[0], result_par[0], atol=1e-10)\n    assert_allclose(result_seq[1], result_par[1], atol=1e-10)\n\n    # Check derivatives against finite differences.\n    aba_par = lambda W_p_B, W_Q_B, s, W_v_WB, ṡ, i_X_λi, τ, W_f_L, g: jaxsim.rbda.aba_parallel(\n        model=model,\n        base_position=W_p_B,\n        
base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B),\n        joint_positions=s,\n        base_linear_velocity=W_v_WB[0:3],\n        base_angular_velocity=W_v_WB[3:6],\n        joint_velocities=ṡ,\n        joint_forces=τ,\n        joint_transforms=i_X_λi,\n        link_forces=W_f_L,\n        standard_gravity=g,\n    )\n\n    check_grads(\n        f=aba_par,\n        args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, i_X_λi, τ, W_f_L, g),\n        order=AD_ORDER,\n        modes=[\"rev\", \"fwd\"],\n        eps=ε,\n    )\n\n\ndef test_ad_rnea(\n    jaxsim_models_types: js.model.JaxSimModel,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    key, subkey = jax.random.split(prng_key, num=2)\n    data, references = get_random_data_and_references(\n        model=model, velocity_representation=VelRepr.Inertial, key=subkey\n    )\n\n    # Get the standard gravity constant.\n    g = jaxsim.math.STANDARD_GRAVITY\n\n    # State in VelRepr.Inertial representation.\n    W_p_B = data.base_position\n    W_Q_B = data.base_orientation\n    s = data.joint_positions\n    W_v_WB = data.base_velocity\n    ṡ = data.joint_velocities\n    i_X_λi = data._joint_transforms\n\n    # Inputs.\n    W_f_L = references.link_forces(model=model)\n\n    # ====\n    # Test\n    # ====\n\n    _, subkey1, subkey2 = jax.random.split(key, num=3)\n    W_v̇_WB = jax.random.uniform(subkey1, shape=(6,), minval=-1)\n    s̈ = jax.random.uniform(subkey2, shape=(model.dofs(),), minval=-1)\n\n    # Get a closure exposing only the parameters to be differentiated.\n    rnea = lambda W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, i_X_λi, W_f_L, g: jaxsim.rbda.rnea(\n        model=model,\n        base_position=W_p_B,\n        base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B),\n        joint_positions=s,\n        base_linear_velocity=W_v_WB[0:3],\n        base_angular_velocity=W_v_WB[3:6],\n        joint_velocities=ṡ,\n        base_linear_acceleration=W_v̇_WB[0:3],\n        base_angular_acceleration=W_v̇_WB[3:6],\n        
joint_accelerations=s̈,\n        joint_transforms=i_X_λi,\n        link_forces=W_f_L,\n        standard_gravity=g,\n    )\n\n    # Check derivatives against finite differences.\n    check_grads(\n        f=rnea,\n        args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, i_X_λi, W_f_L, g),\n        order=AD_ORDER,\n        modes=[\"rev\", \"fwd\"],\n        eps=ε,\n    )\n\n\ndef test_ad_crba(\n    jaxsim_models_types: js.model.JaxSimModel,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data, _ = get_random_data_and_references(\n        model=model, velocity_representation=VelRepr.Inertial, key=subkey\n    )\n\n    # State in VelRepr.Inertial representation.\n    s = data.joint_positions\n\n    # ====\n    # Test\n    # ====\n\n    # Get a closure exposing only the parameters to be differentiated.\n    crba = lambda s: jaxsim.rbda.crba(model=model, joint_positions=s)\n\n    # Check derivatives against finite differences.\n    check_grads(\n        f=crba,\n        args=(s,),\n        order=AD_ORDER,\n        modes=[\"rev\", \"fwd\"],\n        eps=ε,\n    )\n\n\ndef test_ad_fk(\n    jaxsim_models_types: js.model.JaxSimModel,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data, _ = get_random_data_and_references(\n        model=model, velocity_representation=VelRepr.Inertial, key=subkey\n    )\n\n    # State in VelRepr.Inertial representation.\n    W_p_B = data.base_position\n    W_Q_B = data.base_orientation\n    s = data.joint_positions\n    W_v_lin = data._base_linear_velocity\n    W_v_ang = data._base_angular_velocity\n    ṡ = data.joint_velocities\n\n    # ====\n    # Test\n    # ====\n\n    # Get a closure exposing only the parameters to be differentiated.\n    def fk(W_p_B, W_Q_B, s, W_v_lin, W_v_ang, ṡ):\n        import jaxlie\n\n        W_H_B = jaxlie.SE3.from_rotation_and_translation(\n            
rotation=jaxlie.SO3(wxyz=W_Q_B / jnp.linalg.norm(W_Q_B)),\n            translation=W_p_B,\n        ).as_matrix()\n        joint_transforms = model.kin_dyn_parameters.joint_transforms(\n            joint_positions=s, base_transform=W_H_B\n        )\n        return jaxsim.rbda.forward_kinematics_model(\n            model=model,\n            base_position=W_p_B,\n            base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B),\n            joint_positions=s,\n            base_linear_velocity_inertial=W_v_lin,\n            base_angular_velocity_inertial=W_v_ang,\n            joint_velocities=ṡ,\n            joint_transforms=joint_transforms,\n        )\n\n    # Check derivatives against finite differences.\n    check_grads(\n        f=fk,\n        args=(W_p_B, W_Q_B, s, W_v_lin, W_v_ang, ṡ),\n        order=AD_ORDER,\n        modes=[\"rev\", \"fwd\"],\n        eps=ε,\n    )\n\n\ndef test_ad_jacobian(\n    jaxsim_models_types: js.model.JaxSimModel,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data, _ = get_random_data_and_references(\n        model=model, velocity_representation=VelRepr.Inertial, key=subkey\n    )\n\n    # State in VelRepr.Inertial representation.\n    s = data.joint_positions\n\n    # ====\n    # Test\n    # ====\n\n    # Get the link indices.\n    link_indices = jnp.arange(model.number_of_links())\n\n    # Get a closure exposing only the parameters to be differentiated.\n    # We differentiate the jacobian of the last link, likely among those\n    # farther from the base.\n    jacobian = lambda s: jaxsim.rbda.jacobian(\n        model=model, joint_positions=s, link_index=link_indices[-1]\n    )\n\n    # Check derivatives against finite differences.\n    check_grads(\n        f=jacobian,\n        args=(s,),\n        order=AD_ORDER,\n        modes=[\"rev\", \"fwd\"],\n        eps=ε,\n    )\n\n\ndef test_ad_soft_contacts(\n    jaxsim_models_types: js.model.JaxSimModel,\n    
prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    with model.editable(validate=False) as model:\n        model.contact_model = jaxsim.rbda.contacts.SoftContacts.build()\n\n    _, subkey1, subkey2, subkey3 = jax.random.split(prng_key, num=4)\n    p = jax.random.uniform(subkey1, shape=(3,), minval=-1)\n    v = jax.random.uniform(subkey2, shape=(3,), minval=-1)\n    m = jax.random.uniform(subkey3, shape=(3,), minval=-1)\n\n    # Get the soft contacts parameters.\n    parameters = js.contact.estimate_good_contact_parameters(model=model)\n\n    # ====\n    # Test\n    # ====\n\n    # Get a closure exposing only the parameters to be differentiated.\n    def close_over_inputs_and_parameters(\n        p: jtp.VectorLike,\n        v: jtp.VectorLike,\n        m: jtp.VectorLike,\n        params: SoftContactsParams,\n    ) -> tuple[jtp.Vector, jtp.Vector]:\n\n        W_f_Ci, CW_ṁ = SoftContacts.compute_contact_force(\n            position=p,\n            velocity=v,\n            tangential_deformation=m,\n            parameters=params,\n            terrain=model.terrain,\n        )\n\n        return W_f_Ci, CW_ṁ\n\n    # Check derivatives against finite differences.\n    check_grads(\n        f=close_over_inputs_and_parameters,\n        args=(p, v, m, parameters),\n        order=AD_ORDER,\n        modes=[\"rev\", \"fwd\"],\n        eps=ε,\n        # On GPU, the tolerance needs to be increased.\n        rtol=0.02 if \"gpu\" in {d.platform for d in p.devices()} else None,\n    )\n\n\ndef test_ad_integration(\n    jaxsim_models_types: js.model.JaxSimModel,\n    prng_key: jax.Array,\n):\n\n    model = jaxsim_models_types\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    data, references = get_random_data_and_references(\n        model=model, velocity_representation=VelRepr.Inertial, key=subkey\n    )\n\n    # State in VelRepr.Inertial representation.\n    W_p_B = data.base_position\n    W_Q_B = data.base_orientation\n    s = data.joint_positions\n    
W_v_WB = data.base_velocity\n    ṡ = data.joint_velocities\n\n    # Inputs.\n    W_f_L = references.link_forces(model=model)\n    τ = references.joint_force_references(model=model)\n\n    # ====\n    # Test\n    # ====\n\n    # Function exposing only the parameters to be differentiated.\n    def step(\n        W_p_B: jtp.Vector,\n        W_Q_B: jtp.Vector,\n        s: jtp.Vector,\n        W_v_WB: jtp.Vector,\n        ṡ: jtp.Vector,\n        τ: jtp.Vector,\n        W_f_L: jtp.Matrix,\n    ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:\n\n        # When JAX tests against finite differences, the injected ε will make the\n        # quaternion non-unitary, which will cause the AD check to fail.\n        W_Q_B = W_Q_B / jnp.linalg.norm(W_Q_B)\n\n        data_x0 = data.replace(\n            model=model,\n            base_position=W_p_B,\n            base_quaternion=W_Q_B,\n            joint_positions=s,\n            base_linear_velocity=W_v_WB[0:3],\n            base_angular_velocity=W_v_WB[3:6],\n            joint_velocities=ṡ,\n        )\n\n        data_xf = js.model.step(\n            model=model,\n            data=data_x0,\n            joint_force_references=τ,\n            link_forces=W_f_L,\n        )\n\n        xf_W_p_B = data_xf.base_position\n        xf_W_Q_B = data_xf.base_orientation\n        xf_s = data_xf.joint_positions\n        xf_W_v_WB = data_xf.base_velocity\n        xf_ṡ = data_xf.joint_velocities\n\n        return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ\n\n    # Check derivatives against finite differences.\n    check_grads(\n        f=step,\n        args=(W_p_B, W_Q_B, s, W_v_WB, ṡ, τ, W_f_L),\n        order=AD_ORDER,\n        modes=[\"fwd\", \"rev\"],\n        eps=ε,\n    )\n\n\ndef test_ad_safe_norm(\n    prng_key: jax.Array,\n):\n\n    _, subkey = jax.random.split(prng_key, num=2)\n    array = jax.random.uniform(subkey, shape=(4,), minval=-5, maxval=5)\n\n    # ====\n    # Test\n    # ====\n\n    # Test that the 
safe_norm function is compatible with batching.\n    array = jnp.stack([array, array])\n    assert jaxsim.math.safe_norm(array, axis=-1).shape == (2,)\n\n    # Test that the safe_norm function is correctly computing the norm.\n    assert_allclose(\n        jaxsim.math.safe_norm(array, axis=-1), np.linalg.norm(array, axis=-1)\n    )\n\n    # Function exposing only the parameters to be differentiated.\n    def safe_norm(array: jtp.Array) -> jtp.Array:\n\n        return jaxsim.math.safe_norm(array, axis=-1)\n\n    # Check derivatives against finite differences.\n    check_grads(\n        f=safe_norm,\n        args=(array,),\n        order=AD_ORDER,\n        modes=[\"rev\", \"fwd\"],\n        eps=ε,\n    )\n\n    # Check derivatives against finite differences when the array is zero.\n    check_grads(\n        f=safe_norm,\n        args=(jnp.zeros_like(array),),\n        order=AD_ORDER,\n        modes=[\"rev\", \"fwd\"],\n        eps=ε,\n    )\n\n\ndef test_ad_hw_parameters(\n    jaxsim_model_garpez: js.model.JaxSimModel,\n    prng_key: jax.Array,\n):\n    \"\"\"\n    Test the automatic differentiation capability for hardware parameters of the model links.\n    \"\"\"\n\n    model = jaxsim_model_garpez\n    data = js.data.JaxSimModelData.build(model=model)\n\n    min_val, max_val = 0.5, 10.0\n\n    # Generate random scaling factors for testing.\n    _, subkey1, subkey2 = jax.random.split(prng_key, num=3)\n    dims_scaling = jax.random.uniform(\n        subkey1, shape=(model.number_of_links(), 3), minval=min_val, maxval=max_val\n    )\n    density_scaling = jax.random.uniform(\n        subkey2, shape=(model.number_of_links(),), minval=min_val, maxval=max_val\n    )\n\n    scaling_factors = js.kin_dyn_parameters.ScalingFactors(\n        dims=dims_scaling, density=density_scaling\n    )\n\n    link_idx = js.link.name_to_idx(model, link_name=\"link4\")\n\n    # Define a function that updates hardware parameters and computes FK for link 4.\n    def 
update_hw_params_and_compute_fk_and_mass(\n        scaling_factors: js.kin_dyn_parameters.ScalingFactors,\n    ):\n        # Update hardware parameters.\n        updated_model = js.model.update_hw_parameters(\n            model=model, scaling_factors=scaling_factors\n        )\n\n        # Compute forward kinematics for link 4.\n        W_H_L4 = js.model.forward_kinematics(model=updated_model, data=data)[link_idx]\n\n        # Compute the free floating mass matrix of the updated model.\n        M = js.model.free_floating_mass_matrix(updated_model, data)\n\n        return W_H_L4[:3, 3], M\n\n    # Check derivatives against finite differences.\n    check_grads(\n        f=update_hw_params_and_compute_fk_and_mass,\n        args=(scaling_factors,),\n        order=AD_ORDER,\n        modes=[\"fwd\", \"rev\"],\n        eps=ε,\n    )\n"
  },
  {
    "path": "tests/test_benchmark.py",
    "content": "from collections.abc import Callable\n\nimport jax\nimport jax.numpy as jnp\nimport pytest\n\nimport jaxsim\nimport jaxsim.api as js\nfrom jaxsim.api.kin_dyn_parameters import ScalingFactors\n\n\ndef vectorize_data(model: js.model.JaxSimModel, batch_size: int):\n    key = jax.random.PRNGKey(seed=0)\n    keys = jax.random.split(key, num=batch_size)\n\n    return jax.vmap(\n        lambda key: js.data.random_model_data(\n            model=model,\n            key=key,\n        )\n    )(keys)\n\n\ndef benchmark_test_function(\n    func: Callable, model: js.model.JaxSimModel, benchmark, batch_size\n):\n    \"\"\"Reusability wrapper for benchmark tests.\"\"\"\n    data = vectorize_data(model=model, batch_size=batch_size)\n\n    # Warm-up call to avoid including compilation time\n    jax.vmap(func, in_axes=(None, 0))(model, data)\n\n    # Benchmark the function call\n    # Note: jax.block_until_ready is used to ensure that the benchmark is not measuring only the asynchronous dispatch\n    benchmark(jax.block_until_ready(jax.vmap(func, in_axes=(None, 0))), model, data)\n\n\n@pytest.mark.benchmark\ndef test_forward_dynamics_aba(\n    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size\n):\n    model = jaxsim_model_ergocub_reduced\n\n    benchmark_test_function(js.model.forward_dynamics_aba, model, benchmark, batch_size)\n\n\n@pytest.mark.benchmark\ndef test_free_floating_bias_forces(\n    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size\n):\n    model = jaxsim_model_ergocub_reduced\n\n    benchmark_test_function(\n        js.model.free_floating_bias_forces, model, benchmark, batch_size\n    )\n\n\n@pytest.mark.benchmark\ndef test_forward_kinematics(\n    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size\n):\n    model = jaxsim_model_ergocub_reduced\n\n    benchmark_test_function(js.model.forward_kinematics, model, benchmark, batch_size)\n\n\n@pytest.mark.benchmark\ndef 
test_free_floating_mass_matrix(\n    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size\n):\n    model = jaxsim_model_ergocub_reduced\n\n    benchmark_test_function(\n        js.model.free_floating_mass_matrix, model, benchmark, batch_size\n    )\n\n\n@pytest.mark.benchmark\ndef test_free_floating_jacobian(\n    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size\n):\n    model = jaxsim_model_ergocub_reduced\n\n    benchmark_test_function(\n        js.model.generalized_free_floating_jacobian, model, benchmark, batch_size\n    )\n\n\n@pytest.mark.benchmark\ndef test_free_floating_jacobian_derivative(\n    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size\n):\n    model = jaxsim_model_ergocub_reduced\n\n    benchmark_test_function(\n        js.model.generalized_free_floating_jacobian_derivative,\n        model,\n        benchmark,\n        batch_size,\n    )\n\n\n@pytest.mark.benchmark\ndef test_soft_contact_model(\n    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size\n):\n    model = jaxsim_model_ergocub_reduced\n\n    with model.editable(validate=False) as model:\n        model.contact_model = jaxsim.rbda.contacts.SoftContacts()\n        model.contact_params = js.contact.estimate_good_contact_parameters(model=model)\n\n    benchmark_test_function(js.ode.system_dynamics, model, benchmark, batch_size)\n\n\n@pytest.mark.benchmark\ndef test_rigid_contact_model(\n    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size\n):\n    model = jaxsim_model_ergocub_reduced\n\n    with model.editable(validate=False) as model:\n        model.contact_model = jaxsim.rbda.contacts.RigidContacts()\n        model.contact_params = js.contact.estimate_good_contact_parameters(model=model)\n\n    benchmark_test_function(js.ode.system_dynamics, model, benchmark, batch_size)\n\n\n@pytest.mark.benchmark\ndef test_relaxed_rigid_contact_model(\n    jaxsim_model_ergocub_reduced: 
js.model.JaxSimModel, benchmark, batch_size\n):\n    model = jaxsim_model_ergocub_reduced\n\n    with model.editable(validate=False) as model:\n        model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts()\n        model.contact_params = js.contact.estimate_good_contact_parameters(model=model)\n\n    benchmark_test_function(js.ode.system_dynamics, model, benchmark, batch_size)\n\n\n@pytest.mark.benchmark\ndef test_simulation_step(\n    jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size\n):\n    model = jaxsim_model_ergocub_reduced\n\n    with model.editable(validate=False) as model:\n        model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts()\n        model.contact_params = js.contact.estimate_good_contact_parameters(model=model)\n\n    benchmark_test_function(js.model.step, model, benchmark, batch_size)\n\n\n@pytest.mark.benchmark\ndef test_update_hw_parameters(\n    jaxsim_model_garpez: js.model.JaxSimModel, benchmark, batch_size\n):\n    \"\"\"Benchmark hardware parameter scaling/update operation (vmapped).\"\"\"\n    model = jaxsim_model_garpez\n    n_links = model.number_of_links()\n\n    # Create a function that generates random scaling factors and updates the model\n    def update_with_random_scaling(key: jax.Array) -> js.model.JaxSimModel:\n        # Generate scaling factors in a reasonable range [0.8, 1.2]\n        dims_scale = jax.random.uniform(key, shape=(n_links, 3), minval=0.8, maxval=1.2)\n        density_scale = jax.random.uniform(\n            jax.random.fold_in(key, 1), shape=(n_links,), minval=0.8, maxval=1.2\n        )\n        scaling_factors = ScalingFactors(dims=dims_scale, density=density_scale)\n        return js.model.update_hw_parameters(model, scaling_factors)\n\n    # Generate batch of random keys\n    key = jax.random.PRNGKey(seed=42)\n    keys = jax.random.split(key, num=batch_size)\n\n    # Warm-up call to avoid including compilation time\n    
jax.vmap(update_with_random_scaling)(keys)\n\n    # Benchmark the vmapped update operation\n    benchmark(jax.block_until_ready(jax.vmap(update_with_random_scaling)), keys)\n\n\n@pytest.mark.benchmark\ndef test_export_updated_model(\n    jaxsim_model_garpez: js.model.JaxSimModel, benchmark, batch_size\n):\n    \"\"\"Benchmark model export after hardware parameter update.\"\"\"\n    model = jaxsim_model_garpez\n    n_links = model.number_of_links()\n\n    # Create multiple scaled models for benchmarking\n    # Each with slightly different scaling to simulate realistic scenarios\n    key = jax.random.PRNGKey(seed=42)\n    scaling_values = jax.random.uniform(\n        key, shape=(batch_size,), minval=0.9, maxval=1.2\n    )\n\n    updated_models = []\n    for scale_value in scaling_values:\n        scaling_factors = ScalingFactors(\n            dims=jnp.ones((n_links, 3)) * float(scale_value),\n            density=jnp.ones(n_links),\n        )\n        updated_models.append(js.model.update_hw_parameters(model, scaling_factors))\n\n    # Benchmark the export operation (sequentially for all models)\n    # Note: This is not JIT-compiled since it returns a string (URDF/SDF)\n    def export_all():\n        return [m.export_updated_model() for m in updated_models]\n\n    benchmark(export_all)\n"
  },
  {
    "path": "tests/test_exceptions.py",
    "content": "import io\nfrom contextlib import redirect_stdout\n\nimport chex\nimport jax\nimport jax.numpy as jnp\nimport pytest\nfrom jax.errors import JaxRuntimeError\n\nfrom jaxsim import exceptions\n\n\ndef test_exceptions_in_jit_functions():\n\n    msg_during_jit = \"Compiling jit_compiled_function\"\n\n    @jax.jit\n    @chex.assert_max_traces(n=1)\n    def jit_compiled_function(data: jax.Array) -> jax.Array:\n\n        # This message is compiled only during JIT compilation.\n        print(msg_during_jit)\n\n        # Condition that will trigger the exception.\n        failed_if_42_plus = jnp.allclose(data, 42)\n\n        # Raise a ValueError if the condition is met.\n        # The fmt string is built from kwargs.\n        exceptions.raise_value_error_if(\n            condition=failed_if_42_plus,\n            msg=\"Raising ValueError since data={num}\",\n            num=data,\n        )\n\n        # Condition that will trigger the exception.\n        failed_if_42_minus = jnp.allclose(data, -42)\n\n        # Raise a RuntimeError if the condition is met.\n        # The fmt string is built from args.\n        exceptions.raise_runtime_error_if(\n            failed_if_42_minus,\n            \"Raising RuntimeError since data={}\",\n            data,\n        )\n\n        return data\n\n    # In the first call, the function will be compiled and print the message.\n    with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf):\n\n        data = 40\n        out = jit_compiled_function(data=data)\n        stdout = buf.getvalue()\n        assert out == data\n\n    assert msg_during_jit in stdout\n\n    # In the second call, the function won't be compiled and won't print the message.\n    with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf):\n\n        data = 41\n        out = jit_compiled_function(data=data)\n        stdout = buf.getvalue()\n        assert out == data\n\n    assert msg_during_jit not in stdout\n\n    # Let's trigger a 
ValueError exception by passing 42.\n    data = 42\n    with pytest.raises(\n        JaxRuntimeError,\n        match=f\"ValueError: Raising ValueError since data={data}\",\n    ):\n        _ = jit_compiled_function(data=data)\n\n    # Let's trigger a RuntimeError exception by passing -42.\n    data = -42\n    with pytest.raises(\n        JaxRuntimeError,\n        match=f\"RuntimeError: Raising RuntimeError since data={data}\",\n    ):\n        _ = jit_compiled_function(data=data)\n"
  },
  {
    "path": "tests/test_meshes.py",
    "content": "import trimesh\n\nfrom jaxsim.parsers.rod import meshes\n\n\ndef test_mesh_wrapping_vertex_extraction():\n    \"\"\"\n    Test the vertex extraction method on different meshes.\n\n    1. A simple box.\n    2. A sphere.\n    \"\"\"\n\n    # Test 1: A simple box.\n    #     First, create a box with origin at (0,0,0) and extents (3,3,3),\n    #     i.e. points span from -1.5 to 1.5 on the axis.\n    mesh = trimesh.creation.box(\n        extents=[3.0, 3.0, 3.0],\n    )\n    points = meshes.extract_points_vertices(mesh=mesh)\n    assert len(points) == len(mesh.vertices)\n\n    # Test 2: A sphere.\n    #     The sphere is centered at the origin and has a radius of 1.0.\n    mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0)\n    points = meshes.extract_points_vertices(mesh=mesh)\n    assert len(points) == len(mesh.vertices)\n\n\ndef test_mesh_wrapping_aap():\n    \"\"\"\n    Test the AAP wrapping method on different meshes.\n\n    1. A simple box\n        1.1: Keep only the points above x=0.0\n        1.2: Keep only the points below y=0.0\n    2. A sphere\n    \"\"\"\n\n    # Test 1.1: Keep only the points above x=0.0.\n    #     The expected result is that the number of points is halved.\n    #     First, create a box with origin at (0,0,0) and extents (3,3,3),\n    #     i.e. 
points span from -1.5 to 1.5 on the axis.\n    mesh = trimesh.creation.box(extents=[3.0, 3.0, 3.0])\n    points = meshes.extract_points_aap(mesh=mesh, axis=\"x\", lower=0.0)\n    assert len(points) == len(mesh.vertices) // 2\n    assert all(points[:, 0] > 0.0)\n\n    # Test 1.2: Remove all points below y=0.0.\n    #     The expected result is that the number of points is halved.\n    points = meshes.extract_points_aap(mesh=mesh, axis=\"y\", upper=0.0)\n    assert len(points) == len(mesh.vertices) // 2\n    assert all(points[:, 1] < 0.0)\n\n    # Test 2: A sphere.\n    #     The sphere is centered at the origin and has a radius of 1.0.\n    #     Points are expected to be halved.\n    mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0)\n\n    # Remove all points above y=0.0.\n    points = meshes.extract_points_aap(mesh=mesh, axis=\"y\", lower=0.0)\n    assert all(points[:, 1] >= 0.0)\n    assert len(points) < len(mesh.vertices)\n\n\ndef test_mesh_wrapping_points_over_axis():\n    \"\"\"\n    Test the points over axis method on different meshes.\n\n    1. A simple box\n        1.1: Select 10 points from the lower end of the x-axis\n        1.2: Select 10 points from the higher end of the y-axis\n    2. A sphere\n    \"\"\"\n\n    # Test 1.1: Remove 10 points from the lower end of the x-axis.\n    #     First, create a box with origin at (0,0,0) and extents (3,3,3),\n    #     i.e. 
points span from -1.5 to 1.5 on the axis.\n    mesh = trimesh.creation.box(extents=[3.0, 3.0, 3.0])\n    points = meshes.extract_points_select_points_over_axis(\n        mesh=mesh, axis=\"x\", direction=\"lower\", n=4\n    )\n    assert len(points) == 4\n    assert all(points[:, 0] < 0.0)\n\n    # Test 1.2: Select 10 points from the higher end of the y-axis.\n    points = meshes.extract_points_select_points_over_axis(\n        mesh=mesh, axis=\"y\", direction=\"higher\", n=4\n    )\n    assert len(points) == 4\n    assert all(points[:, 1] > 0.0)\n\n    # Test 2: A sphere.\n    #     The sphere is centered at the origin and has a radius of 1.0.\n    mesh = trimesh.creation.icosphere(subdivisions=4, radius=1.0)\n    sphere_n_vertices = len(mesh.vertices)\n\n    # Select 10 points from the higher end of the z-axis.\n    points = meshes.extract_points_select_points_over_axis(\n        mesh=mesh, axis=\"z\", direction=\"higher\", n=sphere_n_vertices // 2\n    )\n    assert len(points) == sphere_n_vertices // 2\n    assert all(points[:, 2] >= 0.0)\n"
  },
  {
    "path": "tests/test_pytree.py",
    "content": "import io\nimport pathlib\nfrom contextlib import redirect_stdout\n\nimport chex\nimport jax\nimport jax.numpy as jnp\nimport pytest\n\nimport jaxsim.api as js\n\n\ndef test_call_jit_compiled_function_passing_different_objects(\n    ergocub_model_description_path: pathlib.Path, jaxsim_model_box\n):\n\n    # Create a first model from the URDF.\n    ergocub_model1 = js.model.JaxSimModel.build_from_model_description(\n        model_description=ergocub_model_description_path\n    )\n\n    # Create a second model from the URDF.\n    ergocub_model2 = js.model.JaxSimModel.build_from_model_description(\n        model_description=ergocub_model_description_path\n    )\n\n    box_model = jaxsim_model_box\n\n    # The objects should be different, but the comparison should return True.\n    assert id(ergocub_model1) != id(ergocub_model2)\n    assert ergocub_model1 == ergocub_model2\n    assert hash(ergocub_model1) == hash(ergocub_model2)\n\n    # If this function has never been compiled by any other test, JAX will\n    # jit-compile it here.\n    _ = js.contact.estimate_good_contact_parameters(model=ergocub_model1)\n\n    # Now JAX should not compile it again.\n    with jax.log_compiles(), io.StringIO() as buf, redirect_stdout(buf):\n        # Beyond running without any JIT recompilations, the following function\n        # should work on different JaxSimModel objects without raising any errors\n        # related to the comparison of Static fields.\n        _ = js.contact.estimate_good_contact_parameters(model=ergocub_model2)\n        stdout = buf.getvalue()\n\n    assert (\n        f\"Compiling {js.contact.estimate_good_contact_parameters.__name__}\"\n        not in stdout\n    )\n\n    # Define a new JIT-compiled function and check that is not recompiled for\n    # different model objects having the same pytree structure.\n    @jax.jit\n    @chex.assert_max_traces(n=1)\n    def my_jit_function(model: js.model.JaxSimModel, data: js.data.JaxSimModelData):\n       
 # Return random elements from model and data, just to have something returned.\n        return (\n            jnp.sum(model.kin_dyn_parameters.link_parameters.mass),\n            data.base_position,\n        )\n\n    data1 = js.data.JaxSimModelData.build(model=ergocub_model1)\n\n    _ = my_jit_function(model=ergocub_model1, data=data1)\n\n    # This should not retrace the function, as ergocub_model2 has the same\n    # pytree structure as ergocub_model1.\n    _ = my_jit_function(model=ergocub_model2, data=data1)\n\n    # Calling the function with a different model object will retrace it, as\n    # expected. Therefore, an AssertionError should be raised.\n    with pytest.raises(\n        AssertionError, match=\"Function 'my_jit_function' is traced > 1 times!\"\n    ):\n        data3 = js.data.JaxSimModelData.build(model=box_model)\n        _ = my_jit_function(model=box_model, data=data3)\n"
  },
  {
    "path": "tests/test_simulations.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport numpy as np\nimport pytest\n\nimport jaxsim.api as js\nimport jaxsim.rbda\nimport jaxsim.typing as jtp\nfrom jaxsim import VelRepr\nfrom jaxsim.api.kin_dyn_parameters import ConstraintType\n\nfrom .utils import assert_allclose\n\n\ndef test_box_with_external_forces(\n    jaxsim_model_box: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n):\n    \"\"\"\n    Simulate a box falling due to gravity.\n\n    We apply to its CoM a 6D force that balances exactly the gravitational force.\n    The box should not fall.\n    \"\"\"\n\n    model = jaxsim_model_box\n\n    # Build the data of the model.\n    data0 = js.data.JaxSimModelData.build(\n        model=model,\n        base_position=jnp.array([0.0, 0.0, 0.5]),\n        velocity_representation=velocity_representation,\n    )\n\n    # Compute the force due to gravity at the CoM.\n    mg = -model.gravity * js.model.total_mass(model=model)\n    G_f = jnp.array([0.0, 0.0, mg, 0, 0, 0])\n\n    # Compute the position of the CoM expressed in the coordinates of the link frame L.\n    L_p_CoM = js.link.com_position(\n        model=model, data=data0, link_index=0, in_link_frame=True\n    )\n\n    # Compute the transform of 6D forces from the CoM to the link frame.\n    L_H_G = jaxsim.math.Transform.from_quaternion_and_translation(translation=L_p_CoM)\n    G_Xv_L = jaxsim.math.Adjoint.from_transform(transform=L_H_G, inverse=True)\n    L_Xf_G = G_Xv_L.T\n    L_f = L_Xf_G @ G_f\n\n    # Initialize a references object that simplifies handling external forces.\n    references = js.references.JaxSimModelReferences.build(\n        model=model,\n        data=data0,\n        velocity_representation=velocity_representation,\n    )\n\n    # Apply a link forces to the base link.\n    with references.switch_velocity_representation(VelRepr.Body):\n        references = references.apply_link_forces(\n            forces=jnp.atleast_2d(L_f),\n            
link_names=model.link_names()[0:1],\n            model=model,\n            data=data0,\n            additive=False,\n        )\n\n    # Initialize the simulation horizon.\n    tf = 0.5\n    T_ns = jnp.arange(start=0, stop=tf * 1e9, step=model.time_step * 1e9, dtype=int)\n\n    # Copy the initial data...\n    data = data0.copy()\n\n    # ... and step the simulation.\n    for _ in T_ns:\n\n        data = js.model.step(\n            model=model,\n            data=data,\n            link_forces=references.link_forces(model, data),\n        )\n\n    # Check that the box didn't move.\n    assert_allclose(data.base_position, data0.base_position)\n    assert_allclose(data.base_orientation, data0.base_orientation)\n\n\ndef test_box_with_zero_gravity(\n    jaxsim_model_box: js.model.JaxSimModel,\n    velocity_representation: VelRepr,\n    prng_key: jnp.ndarray,\n):\n\n    model = jaxsim_model_box\n\n    # Move the terrain (almost) infinitely far away from the box.\n    with model.editable(validate=False) as model:\n        model.terrain = jaxsim.terrain.FlatTerrain.build(height=-1e9)\n        model.gravity = 0.0\n\n    # Split the PRNG key.\n    _, subkey = jax.random.split(prng_key, num=2)\n\n    # Build the data of the model.\n    data0 = js.data.JaxSimModelData.build(\n        model=model,\n        base_position=jax.random.uniform(subkey, shape=(3,)),\n        velocity_representation=velocity_representation,\n    )\n\n    # Initialize a references object that simplifies handling external forces.\n    references = js.references.JaxSimModelReferences.build(\n        model=model,\n        data=data0,\n        velocity_representation=velocity_representation,\n    )\n\n    # Apply a link forces to the base link.\n    with references.switch_velocity_representation(jaxsim.VelRepr.Mixed):\n\n        # Generate a random linear force.\n        # We enforce them to be the same for all velocity representations so that\n        # we can compare their outcomes.\n        LW_f = 10.0 * 
(\n            jax.random.uniform(jax.random.key(0), shape=(model.number_of_links(), 6))\n            .at[:, 3:]\n            .set(jnp.zeros(3))\n        )\n\n        # Note that the context manager does not switch back the newly created\n        # `references` (that is not the yielded object) to the original representation.\n        # In the simulation loop below, we need to make sure that we switch both `data`\n        # and `references` to the same representation before extracting the information\n        # passed to the step function.\n        references = references.apply_link_forces(\n            forces=jnp.atleast_2d(LW_f),\n            link_names=model.link_names(),\n            model=model,\n            data=data0,\n            additive=False,\n        )\n\n    tf = 0.01\n    T = jnp.arange(start=0, stop=tf * 1e9, step=model.time_step * 1e9, dtype=int)\n\n    # Copy the initial data...\n    data = data0.copy()\n\n    # ... and step the simulation.\n    for _ in T:\n        with (\n            data.switch_velocity_representation(velocity_representation),\n            references.switch_velocity_representation(velocity_representation),\n        ):\n            data = js.model.step(\n                model=model,\n                data=data,\n                link_forces=references.link_forces(model=model, data=data),\n            )\n\n    # Check that the box moved as expected.\n    assert_allclose(\n        data.base_position,\n        data0.base_position\n        + 0.5 * LW_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2,\n        atol=1e-3,\n    )\n\n\ndef run_simulation(\n    model: js.model.JaxSimModel,\n    data_t0: js.data.JaxSimModelData,\n    tf: jtp.FloatLike,\n) -> js.data.JaxSimModelData:\n\n    # Initialize the integration horizon.\n    T_ns = jnp.arange(\n        start=0.0, stop=int(tf * 1e9), step=int(model.time_step * 1e9)\n    ).astype(int)\n\n    # Initialize the simulation data.\n    data = data_t0.copy()\n\n    for _ in 
T_ns:\n\n        data = js.model.step(\n            model=model,\n            data=data,\n        )\n\n    return data\n\n\ndef test_simulation_with_soft_contacts(\n    jaxsim_model_box: js.model.JaxSimModel, integrator\n):\n\n    model = jaxsim_model_box\n\n    # Define the maximum penetration of each collidable point at steady state.\n    max_penetration = 0.001\n\n    with model.editable(validate=False) as model:\n\n        model.contact_model = jaxsim.rbda.contacts.SoftContacts.build()\n        model.contact_params = js.contact.estimate_good_contact_parameters(\n            model=model,\n            number_of_active_collidable_points_steady_state=4,\n            static_friction_coefficient=1.0,\n            damping_ratio=1.0,\n            max_penetration=max_penetration,\n        )\n\n        # Enable a subset of the collidable points.\n        enabled_collidable_points_mask = np.zeros(\n            len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool\n        )\n        enabled_collidable_points_mask[[0, 1, 2, 3]] = True\n        model.kin_dyn_parameters.contact_parameters.enabled = tuple(\n            enabled_collidable_points_mask.tolist()\n        )\n\n    assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4\n\n    # Check jaxsim_model_box@conftest.py.\n    box_height = 0.1\n\n    # Build the data of the model.\n    data_t0 = js.data.JaxSimModelData.build(\n        model=model,\n        base_position=jnp.array([0.0, 0.0, box_height * 2]),\n        velocity_representation=VelRepr.Inertial,\n    )\n\n    # ===========================================\n    # Run the simulation and test the final state\n    # ===========================================\n\n    data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0)\n\n    assert_allclose(data_tf.base_position[0:2], data_t0.base_position[0:2])\n    assert_allclose(data_tf.base_position[2] + max_penetration, box_height / 2)\n\n\ndef test_simulation_with_rigid_contacts(\n    
jaxsim_model_box: js.model.JaxSimModel, integrator\n):\n\n    model = jaxsim_model_box\n\n    with model.editable(validate=False) as model:\n\n        # In order to achieve almost no penetration, we need to use a fairly large\n        # Baumgarte stabilization term.\n        model.contact_model = jaxsim.rbda.contacts.RigidContacts.build(\n            solver_options={\"solver_tol\": 1e-3}\n        )\n        model.contact_params = model.contact_model._parameters_class(K=1e5)\n\n        # Enable a subset of the collidable points.\n        enabled_collidable_points_mask = np.zeros(\n            len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool\n        )\n        enabled_collidable_points_mask[[0, 1, 2, 3]] = True\n        model.kin_dyn_parameters.contact_parameters.enabled = tuple(\n            enabled_collidable_points_mask.tolist()\n        )\n\n    assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4\n\n    # Initialize the maximum penetration of each collidable point at steady state.\n    # This model is rigid, so we expect (almost) no penetration.\n    max_penetration = 0.000\n\n    # Check jaxsim_model_box@conftest.py.\n    box_height = 0.1\n\n    # Build the data of the model.\n    data_t0 = js.data.JaxSimModelData.build(\n        model=model,\n        base_position=jnp.array([0.0, 0.0, box_height * 2]),\n        velocity_representation=VelRepr.Inertial,\n    )\n\n    # ===========================================\n    # Run the simulation and test the final state\n    # ===========================================\n\n    data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0)\n\n    assert_allclose(data_tf.base_position[0:2], data_t0.base_position[0:2])\n    assert_allclose(data_tf.base_position[2] + max_penetration, box_height / 2)\n\n\ndef test_simulation_with_relaxed_rigid_contacts(\n    jaxsim_model_box: js.model.JaxSimModel, integrator\n):\n\n    model = jaxsim_model_box\n\n    with model.editable(validate=False) 
as model:\n\n        model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts.build(\n            solver_options={\"tol\": 1e-3},\n        )\n        model.contact_params = model.contact_model._parameters_class()\n\n        # Enable a subset of the collidable points.\n        enabled_collidable_points_mask = np.zeros(\n            len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool\n        )\n        enabled_collidable_points_mask[[0, 1, 2, 3]] = True\n        model.kin_dyn_parameters.contact_parameters.enabled = tuple(\n            enabled_collidable_points_mask.tolist()\n        )\n        model.integrator = integrator\n\n    assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4\n\n    # Initialize the maximum penetration of each collidable point at steady state.\n    # This model is quasi-rigid, so we expect (almost) no penetration.\n    max_penetration = 0.000\n\n    # Check jaxsim_model_box@conftest.py.\n    box_height = 0.1\n\n    # Build the data of the model.\n    data_t0 = js.data.JaxSimModelData.build(\n        model=model,\n        base_position=jnp.array([0.0, 0.0, box_height * 2]),\n        velocity_representation=VelRepr.Inertial,\n    )\n\n    # ===========================================\n    # Run the simulation and test the final state\n    # ===========================================\n\n    data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0)\n\n    # With this contact model, we need to slightly increase the tolerances.\n    assert_allclose(data_tf.base_position[0:2], data_t0.base_position[0:2], atol=1e-5)\n    assert_allclose(\n        data_tf.base_position[2] + max_penetration, box_height / 2, atol=1e-4\n    )\n\n\ndef test_joint_limits(\n    jaxsim_model_single_pendulum: js.model.JaxSimModel,\n):\n\n    model = jaxsim_model_single_pendulum\n\n    with model.editable(validate=False) as model:\n        model.kin_dyn_parameters.joint_parameters.position_limits_max = jnp.atleast_1d(\n            
jnp.array(1.5708)\n        )\n        model.kin_dyn_parameters.joint_parameters.position_limits_min = jnp.atleast_1d(\n            jnp.array(-1.5708)\n        )\n        model.kin_dyn_parameters.joint_parameters.position_limit_spring = (\n            jnp.atleast_1d(jnp.array(75.0))\n        )\n        model.kin_dyn_parameters.joint_parameters.position_limit_damper = (\n            jnp.atleast_1d(jnp.array(0.1))\n        )\n\n    position_limits_min, position_limits_max = js.joint.position_limits(model=model)\n\n    data = js.data.JaxSimModelData.build(\n        model=model,\n        velocity_representation=VelRepr.Inertial,\n    )\n\n    theta = 10 * np.pi / 180\n\n    # Define a tolerance since the spring-damper model does\n    # not guarantee that the joint position will be exactly\n    # below the limit.\n    tolerance = theta * 0.10\n\n    # Test minimum joint position limits.\n    data_t0 = data.replace(model=model, joint_positions=position_limits_min - theta)\n\n    model = model.replace(time_step=0.005, validate=False)\n    data_tf = run_simulation(model=model, data_t0=data_t0, tf=3.0)\n\n    assert (\n        np.min(np.array(data_tf.joint_positions), axis=0) + tolerance\n        >= position_limits_min\n    )\n\n    # Test maximum joint position limits.\n    data_t0 = data.replace(model=model, joint_positions=position_limits_max - theta)\n\n    model = model.replace(time_step=0.001)\n    data_tf = run_simulation(model=model, data_t0=data_t0, tf=3.0)\n\n    assert (\n        np.max(np.array(data_tf.joint_positions), axis=0) - tolerance\n        <= position_limits_max\n    )\n\n\n@pytest.mark.parametrize(\n    \"initial_joint_positions\",\n    [\n        jnp.array([0, 0]),\n        np.pi / 180 * jnp.array([5, 0]),\n    ],\n)\ndef test_simulation_with_kinematic_constraints_double_pendulum(\n    jaxsim_model_double_pendulum: js.model.JaxSimModel,\n    initial_joint_positions: jtp.Array,\n):\n\n    # ========\n    # Arrange\n    # ========\n\n    tf = 1.0  # 
Final simulation time in seconds.\n\n    model = jaxsim_model_double_pendulum\n\n    frame_1_name = \"right_link_extremity_frame\"\n    frame_2_name = \"left_link_extremity_frame\"\n    frame_1_idx = js.frame.name_to_idx(model=model, frame_name=frame_1_name)\n    frame_2_idx = js.frame.name_to_idx(model=model, frame_name=frame_2_name)\n\n    # Define the kinematic constraints.\n    constraints = js.kin_dyn_parameters.ConstraintMap()\n    constraints = constraints.add_constraint(\n        model=model,\n        frame_idx_1=frame_1_idx,\n        frame_idx_2=frame_2_idx,\n        constraint_type=ConstraintType.Weld,\n    )\n\n    # Set the constraints in the model.\n    with model.editable(validate=False) as model:\n        model.kin_dyn_parameters.constraints = constraints\n        model.gravity = 0.0\n\n    # Build the initial data for the model.\n    data_t0 = js.data.JaxSimModelData.build(\n        model=model,\n        velocity_representation=VelRepr.Inertial,\n        joint_positions=initial_joint_positions,\n    )\n\n    # ====\n    # Act\n    # ====\n\n    # Simulate the model for a given time and time step.\n    data_tf = run_simulation(model=model, data_t0=data_t0, tf=tf)\n\n    # =========\n    # Assert\n    # =========\n\n    # Assert that the chosen frames exist in the model\n    assert frame_1_name in model.frame_names()\n    assert frame_2_name in model.frame_names()\n\n    # Assert that the joint positions are now equal\n    actual_delta_s_tf = jnp.abs(data_tf.joint_positions[0] - data_tf.joint_positions[1])\n    expected_delta_s_tf = 0.0\n\n    assert_allclose(\n        expected_delta_s_tf,\n        actual_delta_s_tf,\n        atol=1e-2,\n        err_msg=f\"Position difference [deg]: {actual_delta_s_tf * 180 / np.pi}\",\n    )\n\n\ndef test_simulation_with_kinematic_constraints_cartpole(\n    jaxsim_model_cartpole: js.model.JaxSimModel,\n):\n    # ========\n    # Arrange\n    # ========\n\n    tf = 1.0  # Final simulation time in seconds.\n\n    model 
= jaxsim_model_cartpole\n\n    frame_1_name = \"cart_frame\"\n    frame_2_name = \"rail_frame\"\n    frame_1_idx = js.frame.name_to_idx(model=model, frame_name=frame_1_name)\n    frame_2_idx = js.frame.name_to_idx(model=model, frame_name=frame_2_name)\n\n    # Define the kinematic constraints.\n    constraints = js.kin_dyn_parameters.ConstraintMap()\n    constraints = constraints.add_constraint(\n        model,\n        frame_1_idx,\n        frame_2_idx,\n        ConstraintType.Weld,\n    )\n\n    # Set the initial joint positions with the cart displaced from the rail zero position.\n    initial_joint_positions = jnp.array([0.05, 0.0])\n\n    # Set the constraints in the model.\n    with model.editable(validate=False) as model:\n        model.kin_dyn_parameters.constraints = constraints\n\n    # Build the initial data for the model.\n    data_t0 = js.data.JaxSimModelData.build(\n        model=model,\n        velocity_representation=VelRepr.Inertial,\n        joint_positions=initial_joint_positions,\n    )\n\n    # ====\n    # Act\n    # ====\n\n    # Simulate the model for a given time and time step.\n    data_tf = run_simulation(model=model, data_t0=data_t0, tf=tf)\n\n    H_frame1 = js.frame.transform(\n        model=model,\n        data=data_tf,\n        frame_index=frame_1_idx,\n    )\n    H_frame2 = js.frame.transform(\n        model=model,\n        data=data_tf,\n        frame_index=frame_2_idx,\n    )\n\n    # =========\n    # Assert\n    # =========\n\n    # Assert that the chosen frames exist in the model\n    assert frame_1_name in model.frame_names()\n    assert frame_2_name in model.frame_names()\n\n    # Assert that the two frames are in the same pose\n    actual_frame_error = jnp.linalg.inv(H_frame1) @ H_frame2\n    expected_frame_error = jnp.eye(4)\n\n    assert_allclose(actual_frame_error, expected_frame_error, atol=1e-3)\n\n\ndef test_simulation_with_kinematic_constraints_4_bar_linkage(\n    jaxsim_model_4_bar_linkage: js.model.JaxSimModel,\n):\n    
\"\"\"Test kinematic weld constraint on 4-bar linkage model.\"\"\"\n\n    # ========\n    # Arrange\n    # ========\n\n    tf = 1.0  # Final simulation time in seconds.\n    model = jaxsim_model_4_bar_linkage\n\n    frame_1_name = \"BC1_frame\"\n    frame_2_name = \"BC2_frame\"\n    frame_1_idx = js.frame.name_to_idx(model=model, frame_name=frame_1_name)\n    frame_2_idx = js.frame.name_to_idx(model=model, frame_name=frame_2_name)\n\n    # Define the kinematic constraints.\n    constraints = js.kin_dyn_parameters.ConstraintMap()\n    constraints = constraints.add_constraint(\n        model=model,\n        frame_idx_1=frame_1_idx,\n        frame_idx_2=frame_2_idx,\n        constraint_type=ConstraintType.Weld,\n        K_P=1e4,\n    )\n\n    # Set the constraints in the model.\n    with model.editable(validate=False) as model:\n        model.kin_dyn_parameters.constraints = constraints\n\n    # Build the initial data for the model (default base pose is fine).\n    data_t0 = js.data.JaxSimModelData.build(\n        model=model,\n        velocity_representation=VelRepr.Inertial,\n        base_position=jnp.array([0.0, 0.0, 0.10]),\n    )\n\n    # ====\n    # Act\n    # ====\n    data_tf = run_simulation(model=model, data_t0=data_t0, tf=tf)\n\n    H_frame1 = js.frame.transform(\n        model=model,\n        data=data_tf,\n        frame_index=frame_1_idx,\n    )\n    H_frame2 = js.frame.transform(\n        model=model,\n        data=data_tf,\n        frame_index=frame_2_idx,\n    )\n\n    # =========\n    # Assert\n    # =========\n    assert frame_1_name in model.frame_names()\n    assert frame_2_name in model.frame_names()\n\n    # Position check\n    pos1 = H_frame1[:3, 3]\n    pos2 = H_frame2[:3, 3]\n    assert_allclose(pos1, pos2, atol=1e-5)\n\n    # Orientation check\n    R1 = H_frame1[:3, :3]\n    R2 = H_frame2[:3, :3]\n    R_err = R1.T @ R2\n    assert_allclose(R_err, jnp.eye(3), atol=1e-3)\n"
  },
  {
    "path": "tests/test_visualizer.py",
    "content": "import pytest\nimport rod\n\nfrom jaxsim.mujoco import ModelToMjcf\nfrom jaxsim.mujoco.loaders import MujocoCamera\n\n\n@pytest.fixture\ndef mujoco_camera():\n\n    return MujocoCamera.build_from_target_view(\n        camera_name=\"test_camera\",\n        lookat=(0, 0, 0),\n        distance=1,\n        azimuth=0,\n        elevation=0,\n        fovy=45,\n        degrees=True,\n    )\n\n\ndef test_urdf_loading(jaxsim_model_single_pendulum, mujoco_camera):\n    model = jaxsim_model_single_pendulum.built_from\n\n    _ = ModelToMjcf.convert(model=model, cameras=mujoco_camera)\n\n\ndef test_sdf_loading(jaxsim_model_single_pendulum, mujoco_camera):\n\n    model = rod.Sdf.load(sdf=jaxsim_model_single_pendulum.built_from).serialize(\n        pretty=True\n    )\n\n    _ = ModelToMjcf.convert(model=model, cameras=mujoco_camera)\n\n\ndef test_rod_loading(jaxsim_model_single_pendulum, mujoco_camera):\n\n    model = rod.Sdf.load(sdf=jaxsim_model_single_pendulum.built_from).models()[0]\n\n    _ = ModelToMjcf.convert(model=model, cameras=mujoco_camera)\n\n\ndef test_heightmap(jaxsim_model_single_pendulum, mujoco_camera):\n\n    model = rod.Sdf.load(sdf=jaxsim_model_single_pendulum.built_from).models()[0]\n\n    _ = ModelToMjcf.convert(\n        model=model,\n        cameras=mujoco_camera,\n        heightmap=True,\n        heightmap_samples_xy=(51, 51),\n    )\n\n\ndef test_inclined_plane(jaxsim_model_single_pendulum, mujoco_camera):\n\n    model = rod.Sdf.load(sdf=jaxsim_model_single_pendulum.built_from).models()[0]\n\n    _ = ModelToMjcf.convert(\n        model=model,\n        cameras=mujoco_camera,\n        plane_normal=(0.3, 0.3, 0.3),\n    )\n"
  },
  {
    "path": "tests/utils.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nimport pathlib\n\nimport idyntree.bindings as idt\nimport numpy as np\nimport numpy.typing as npt\n\nimport jaxsim.api as js\nfrom jaxsim import VelRepr\n\n\ndef assert_allclose(actual, desired, rtol=1e-7, atol=1e-9, err_msg=\"\"):\n    \"\"\"\n    Assert allclose with custom default tolerances.\n    Normalizes only signed zeros using np.copysign.\n    \"\"\"\n    actual = np.asarray(actual, dtype=float)\n    desired = np.asarray(desired, dtype=float)\n\n    # Normalize zeros to avoid -0.0 vs 0.0 mismatches.\n    actual = actual + 0.0\n    desired = desired + 0.0\n\n    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, err_msg=err_msg)\n\n\ndef build_kindyncomputations_from_jaxsim_model(\n    model: js.model.JaxSimModel,\n    data: js.data.JaxSimModelData,\n    considered_joints: list[str] | None = None,\n    removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None,\n) -> KinDynComputations:\n    \"\"\"\n    Build a `KinDynComputations` from `JaxSimModel` and `JaxSimModelData`.\n\n    Args:\n        model: The `JaxSimModel` from which to build the `KinDynComputations`.\n        data: The `JaxSimModelData` from which to build the `KinDynComputations`.\n        considered_joints:\n            The list of joint names to consider in the `KinDynComputations`.\n        removed_joint_positions:\n            A dictionary defining the positions of the removed joints (default is 0).\n\n    Returns:\n        The `KinDynComputations` built from the `JaxSimModel` and `JaxSimModelData`.\n\n    Note:\n        Only `JaxSimModel` built from URDF files are supported.\n\n    \"\"\"\n\n    if (\n        isinstance(model.built_from, pathlib.Path)\n        and model.built_from.suffix != \".urdf\"\n    ) or (isinstance(model.built_from, str) and \"<robot\" not in model.built_from):\n        raise ValueError(\"iDynTree only supports URDF models\")\n\n    if not 
data.valid(model=model):\n        raise ValueError(\"Invalid data object for the provided model.\")\n\n    # By default, enforce iDynTree to use the same serialization of the JaxSimModel.\n    considered_joints = (\n        considered_joints if considered_joints is not None else model.joint_names()\n    )\n\n    # Get the default positions already stored in the model description.\n    removed_joint_positions_default = {\n        str(j.name): float(j.initial_position)\n        for j in model.description.joints_removed\n        if j.name not in considered_joints\n    }\n\n    # Pass this dict even if there are no removed joints.\n    removed_joint_positions = removed_joint_positions_default | (\n        removed_joint_positions\n        if removed_joint_positions is not None\n        else dict(\n            zip(\n                model.joint_names(),\n                data.joint_positions,\n                strict=True,\n            )\n        )\n    )\n\n    # Create the KinDynComputations from the same URDF model.\n    kin_dyn = KinDynComputations.build(\n        urdf=model.built_from,\n        considered_joints=considered_joints,\n        vel_repr=data.velocity_representation,\n        gravity=np.array([0, 0, model.gravity]),\n        removed_joint_positions=removed_joint_positions,\n    )\n\n    # Copy the state of the JaxSim model.\n    kin_dyn = store_jaxsim_data_in_kindyncomputations(data=data, kin_dyn=kin_dyn)\n\n    return kin_dyn\n\n\ndef store_jaxsim_data_in_kindyncomputations(\n    data: js.data.JaxSimModelData, kin_dyn: KinDynComputations\n) -> KinDynComputations:\n    \"\"\"\n    Store the state of a `JaxSimModelData` in `KinDynComputations`.\n\n    Args:\n        data:\n            The `JaxSimModelData` providing the desired state to copy.\n        kin_dyn:\n            The `KinDynComputations` in which to store the state of `JaxSimModelData`.\n\n    Returns:\n        The updated `KinDynComputations` with the state of `JaxSimModelData`.\n\n    \"\"\"\n\n   
 if kin_dyn.dofs() != data.joint_positions.size:\n        raise ValueError(data)\n\n    with data.switch_velocity_representation(kin_dyn.vel_repr):\n        kin_dyn.set_robot_state(\n            joint_positions=np.array(data.joint_positions),\n            joint_velocities=np.array(data.joint_velocities),\n            base_transform=np.array(data._base_transform),\n            base_velocity=np.array(data.base_velocity),\n        )\n\n    return kin_dyn\n\n\n@dataclasses.dataclass\nclass KinDynComputations:\n    \"\"\"High-level wrapper of the iDynTree KinDynComputations class.\"\"\"\n\n    vel_repr: VelRepr\n    gravity: npt.NDArray\n    kin_dyn: idt.KinDynComputations\n\n    @staticmethod\n    def build(\n        urdf: pathlib.Path | str,\n        considered_joints: list[str] | None = None,\n        vel_repr: VelRepr = VelRepr.Inertial,\n        gravity: npt.NDArray = np.array([0, 0, -10.0]),\n        removed_joint_positions: dict[str, npt.NDArray | float | int] | None = None,\n    ) -> KinDynComputations:\n\n        # Read the URDF description.\n        urdf_string = urdf.read_text() if isinstance(urdf, pathlib.Path) else urdf\n\n        # Create the model loader.\n        mdl_loader = idt.ModelLoader()\n\n        # Handle removed_joint_positions if None.\n        removed_joint_positions = (\n            {name: float(pos) for name, pos in removed_joint_positions.items()}\n            if removed_joint_positions is not None\n            else {}\n        )\n\n        # Load the URDF description.\n        if not (\n            mdl_loader.loadModelFromString(urdf_string)\n            if considered_joints is None\n            else mdl_loader.loadReducedModelFromString(\n                urdf_string, considered_joints, removed_joint_positions\n            )\n        ):\n            raise RuntimeError(\"Failed to load URDF description\")\n\n        # Create KinDynComputations and insert the model.\n        kindyn = idt.KinDynComputations()\n\n        if not 
kindyn.loadRobotModel(mdl_loader.model()):\n            raise RuntimeError(\"Failed to load model\")\n\n        vel_repr_to_idyntree = {\n            VelRepr.Inertial: idt.INERTIAL_FIXED_REPRESENTATION,\n            VelRepr.Body: idt.BODY_FIXED_REPRESENTATION,\n            VelRepr.Mixed: idt.MIXED_REPRESENTATION,\n        }\n\n        # Configure the frame representation.\n        if not kindyn.setFrameVelocityRepresentation(vel_repr_to_idyntree[vel_repr]):\n            raise RuntimeError(\"Failed to set the frame representation\")\n\n        return KinDynComputations(\n            kin_dyn=kindyn,\n            vel_repr=vel_repr,\n            gravity=np.array(gravity).squeeze(),\n        )\n\n    def set_robot_state(\n        self,\n        joint_positions: npt.NDArray | None = None,\n        joint_velocities: npt.NDArray | None = None,\n        base_transform: npt.NDArray = np.eye(4),\n        base_velocity: npt.NDArray = np.zeros(6),\n        world_gravity: npt.NDArray | None = None,\n    ) -> None:\n\n        joint_positions = (\n            joint_positions if joint_positions is not None else np.zeros(self.dofs())\n        )\n\n        joint_velocities = (\n            joint_velocities if joint_velocities is not None else np.zeros(self.dofs())\n        )\n\n        gravity = world_gravity if world_gravity is not None else self.gravity\n\n        if joint_positions.size != self.dofs():\n            raise ValueError(joint_positions.size, self.dofs())\n\n        if joint_velocities.size != self.dofs():\n            raise ValueError(joint_velocities.size, self.dofs())\n\n        if gravity.size != 3:\n            raise ValueError(gravity.size, 3)\n\n        if base_transform.shape != (4, 4):\n            raise ValueError(base_transform.shape, (4, 4))\n\n        if base_velocity.size != 6:\n            raise ValueError(base_velocity.size)\n\n        g = idt.Vector3().FromPython(np.array(gravity))\n        s = idt.VectorDynSize().FromPython(np.array(joint_positions))\n 
       s_dot = idt.VectorDynSize().FromPython(np.array(joint_velocities))\n\n        p = idt.Position(*[float(i) for i in np.array(base_transform[0:3, 3])])\n        R = idt.Rotation()\n        R = R.FromPython(np.array(base_transform[0:3, 0:3]))\n        world_H_base = idt.Transform()\n        world_H_base.setPosition(p)\n        world_H_base.setRotation(R)\n\n        v_WB = idt.Twist().FromPython(base_velocity)\n\n        if not self.kin_dyn.setRobotState(world_H_base, s, v_WB, s_dot, g):\n            raise RuntimeError(\"Failed to set the robot state\")\n\n        # Update stored gravity.\n        self.gravity = gravity\n\n    def dofs(self) -> int:\n\n        return self.kin_dyn.getNrOfDegreesOfFreedom()\n\n    def joint_names(self) -> list[str]:\n\n        model: idt.Model = self.kin_dyn.model()\n        return [model.getJointName(i) for i in range(model.getNrOfJoints())]\n\n    def link_names(self) -> list[str]:\n\n        return [\n            self.kin_dyn.getFrameName(i) for i in range(self.kin_dyn.getNrOfLinks())\n        ]\n\n    def frame_names(self) -> list[str]:\n\n        return [\n            self.kin_dyn.getFrameName(i)\n            for i in range(self.kin_dyn.getNrOfLinks(), self.kin_dyn.getNrOfFrames())\n        ]\n\n    def joint_positions(self) -> npt.NDArray:\n\n        vector = idt.VectorDynSize()\n\n        if not self.kin_dyn.getJointPos(vector):\n            raise RuntimeError(\"Failed to extract joint positions\")\n\n        return vector.toNumPy()\n\n    def joint_velocities(self) -> npt.NDArray:\n\n        vector = idt.VectorDynSize()\n\n        if not self.kin_dyn.getJointVel(vector):\n            raise RuntimeError(\"Failed to extract joint velocities\")\n\n        return vector.toNumPy()\n\n    def jacobian_frame(self, frame_name: str) -> npt.NDArray:\n\n        if self.kin_dyn.getFrameIndex(frame_name) < 0:\n            raise ValueError(f\"Frame '{frame_name}' does not exist\")\n\n        J = idt.MatrixDynSize(6, 6 + self.dofs())\n\n 
       if not self.kin_dyn.getFrameFreeFloatingJacobian(frame_name, J):\n            raise RuntimeError(\"Failed to get the frame free-floating jacobian\")\n\n        return J.toNumPy()\n\n    def total_mass(self) -> float:\n\n        model: idt.Model = self.kin_dyn.model()\n        return model.getTotalMass()\n\n    def link_spatial_inertia(self, link_name: str) -> npt.NDArray:\n\n        if link_name not in self.link_names():\n            raise ValueError(link_name)\n\n        model = self.kin_dyn.model()\n        link: idt.Link = model.getLink(model.getLinkIndex(link_name))\n\n        return link.inertia().asMatrix().toNumPy()\n\n    def link_mass(self, link_name: str) -> float:\n\n        if link_name not in self.link_names():\n            raise ValueError(link_name)\n\n        model = self.kin_dyn.model()\n        link: idt.Link = model.getLink(model.getLinkIndex(link_name))\n\n        return link.getInertia().asVector().toNumPy()[0]\n\n    def floating_base_frame(self) -> str:\n\n        return self.kin_dyn.getFloatingBase()\n\n    def frame_transform(self, frame_name: str) -> npt.NDArray:\n\n        if self.kin_dyn.getFrameIndex(frame_name) < 0:\n            raise ValueError(f\"Frame '{frame_name}' does not exist\")\n\n        if frame_name == self.floating_base_frame():\n            H_idt = self.kin_dyn.getWorldBaseTransform()\n        else:\n            H_idt = self.kin_dyn.getWorldTransform(frame_name)\n\n        H = np.eye(4)\n        H[0:3, 3] = H_idt.getPosition().toNumPy()\n        H[0:3, 0:3] = H_idt.getRotation().toNumPy()\n\n        return H\n\n    def frame_relative_transform(\n        self, ref_frame_name: str, frame_name: str\n    ) -> npt.NDArray:\n\n        if self.kin_dyn.getFrameIndex(ref_frame_name) < 0:\n            raise ValueError(f\"Frame '{ref_frame_name}' does not exist\")\n\n        if self.kin_dyn.getFrameIndex(frame_name) < 0:\n            raise ValueError(f\"Frame '{frame_name}' does not exist\")\n\n        ref_H_frame: 
idt.Transform = self.kin_dyn.getRelativeTransform(\n            ref_frame_name, frame_name\n        )\n\n        H = np.eye(4)\n        H[0:3, 3] = ref_H_frame.getPosition().toNumPy()\n        H[0:3, 0:3] = ref_H_frame.getRotation().toNumPy()\n\n        return H\n\n    def frame_parent_link_name(self, frame_name: str) -> str:\n        return self.kin_dyn.model().getLinkName(\n            self.kin_dyn.model().getFrameLink(\n                self.kin_dyn.model().getFrameIndex(frame_name)\n            )\n        )\n\n    def base_velocity(self) -> npt.NDArray:\n\n        nu = idt.VectorDynSize()\n\n        if not self.kin_dyn.getModelVel(nu):\n            raise RuntimeError(\"Failed to get the model velocity\")\n\n        return nu.toNumPy()[0:6]\n\n    def frame_velocity(self, frame_name: str) -> npt.NDArray:\n\n        if self.kin_dyn.getFrameIndex(frame_name) < 0:\n            raise ValueError(f\"Frame '{frame_name}' does not exist\")\n\n        v_WF = self.kin_dyn.getFrameVel(frame_name)\n\n        return v_WF.toNumPy()\n\n    def frame_bias_acc(self, frame_name: str) -> npt.NDArray:\n\n        if self.kin_dyn.getFrameIndex(frame_name) < 0:\n            raise ValueError(f\"Frame '{frame_name}' does not exist\")\n\n        J̇ν = self.kin_dyn.getFrameBiasAcc(frame_name)\n\n        return J̇ν.toNumPy()\n\n    def com_position(self) -> npt.NDArray:\n\n        W_p_G = self.kin_dyn.getCenterOfMassPosition()\n        return W_p_G.toNumPy()\n\n    def com_velocity(self) -> npt.NDArray:\n\n        W_ṗ_G = self.kin_dyn.getCenterOfMassVelocity()\n        return W_ṗ_G.toNumPy()\n\n    def com_bias_acceleration(self) -> npt.NDArray:\n\n        return self.kin_dyn.getCenterOfMassBiasAcc().toNumPy()\n\n    def mass_matrix(self) -> npt.NDArray:\n\n        M = idt.MatrixDynSize()\n\n        if not self.kin_dyn.getFreeFloatingMassMatrix(M):\n            raise RuntimeError(\"Failed to get the free floating mass matrix\")\n\n        return M.toNumPy()\n\n    def bias_forces(self) -> 
npt.NDArray:\n\n        h = idt.FreeFloatingGeneralizedTorques(self.kin_dyn.model())\n\n        if not self.kin_dyn.generalizedBiasForces(h):\n            raise RuntimeError(\"Failed to get the generalized bias forces\")\n\n        base_wrench: idt.Wrench = h.baseWrench()\n        joint_torques: idt.JointDOFsDoubleArray = h.jointTorques()\n\n        return np.hstack(\n            [base_wrench.toNumPy().flatten(), joint_torques.toNumPy().flatten()]\n        )\n\n    def gravity_forces(self) -> npt.NDArray:\n\n        g = idt.FreeFloatingGeneralizedTorques(self.kin_dyn.model())\n\n        if not self.kin_dyn.generalizedGravityForces(g):\n            raise RuntimeError(\"Failed to get the generalized gravity forces\")\n\n        base_wrench: idt.Wrench = g.baseWrench()\n        joint_torques: idt.JointDOFsDoubleArray = g.jointTorques()\n\n        return np.hstack(\n            [base_wrench.toNumPy().flatten(), joint_torques.toNumPy().flatten()]\n        )\n\n    def total_momentum(self) -> npt.NDArray:\n\n        return self.kin_dyn.getLinearAngularMomentum().toNumPy().flatten()\n\n    def centroidal_momentum(self) -> npt.NDArray:\n\n        return self.kin_dyn.getCentroidalTotalMomentum().toNumPy().flatten()\n\n    def total_momentum_jacobian(self) -> npt.NDArray:\n\n        Jh = idt.MatrixDynSize()\n\n        if not self.kin_dyn.getLinearAngularMomentumJacobian(Jh):\n            raise RuntimeError(\"Failed to get the total momentum jacobian\")\n\n        return Jh.toNumPy()\n\n    def centroidal_momentum_jacobian(self) -> npt.NDArray:\n\n        Jh = idt.MatrixDynSize()\n\n        if not self.kin_dyn.getCentroidalTotalMomentumJacobian(Jh):\n            raise RuntimeError(\"Failed to get the centroidal momentum jacobian\")\n\n        return Jh.toNumPy()\n\n    def locked_spatial_inertia(self) -> npt.NDArray:\n\n        return self.kin_dyn.getRobotLockedInertia().asMatrix().toNumPy()\n\n    def locked_centroidal_spatial_inertia(self) -> npt.NDArray:\n\n        return 
self.kin_dyn.getCentroidalRobotLockedInertia().asMatrix().toNumPy()\n\n    def average_velocity(self) -> npt.NDArray:\n\n        return self.kin_dyn.getAverageVelocity().toNumPy()\n\n    def average_velocity_jacobian(self) -> npt.NDArray:\n\n        Jh = idt.MatrixDynSize()\n\n        if not self.kin_dyn.getAverageVelocityJacobian(Jh):\n            raise RuntimeError(\"Failed to get the average velocity jacobian\")\n\n        return Jh.toNumPy()\n\n    def average_centroidal_velocity(self) -> npt.NDArray:\n\n        return self.kin_dyn.getCentroidalAverageVelocity().toNumPy()\n\n    def average_centroidal_velocity_jacobian(self) -> npt.NDArray:\n\n        Jh = idt.MatrixDynSize()\n\n        if not self.kin_dyn.getCentroidalAverageVelocityJacobian(Jh):\n            raise RuntimeError(\"Failed to get the average centroidal velocity jacobian\")\n\n        return Jh.toNumPy()\n"
  }
]