Repository: Physical-Intelligence/openpi
Branch: main
Commit: 54cbaee6ae0c
Files: 139
Total size: 789.8 KB

Directory structure:
gitextract_6rhyljlr/

├── .dockerignore
├── .github/
│   ├── CODEOWNERS
│   └── workflows/
│       ├── pre-commit.yml
│       └── test.yml
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── .python-version
├── .vscode/
│   └── settings.json
├── CONTRIBUTING.md
├── LICENSE
├── LICENSE_GEMMA.txt
├── README.md
├── docs/
│   ├── docker.md
│   ├── norm_stats.md
│   └── remote_inference.md
├── examples/
│   ├── aloha_real/
│   │   ├── Dockerfile
│   │   ├── README.md
│   │   ├── compose.yml
│   │   ├── constants.py
│   │   ├── convert_aloha_data_to_lerobot.py
│   │   ├── env.py
│   │   ├── main.py
│   │   ├── real_env.py
│   │   ├── requirements.in
│   │   ├── requirements.txt
│   │   ├── robot_utils.py
│   │   └── video_display.py
│   ├── aloha_sim/
│   │   ├── Dockerfile
│   │   ├── README.md
│   │   ├── compose.yml
│   │   ├── env.py
│   │   ├── main.py
│   │   ├── requirements.in
│   │   ├── requirements.txt
│   │   └── saver.py
│   ├── convert_jax_model_to_pytorch.py
│   ├── droid/
│   │   ├── README.md
│   │   ├── README_train.md
│   │   ├── compute_droid_nonidle_ranges.py
│   │   ├── convert_droid_data_to_lerobot.py
│   │   └── main.py
│   ├── inference.ipynb
│   ├── libero/
│   │   ├── Dockerfile
│   │   ├── README.md
│   │   ├── compose.yml
│   │   ├── convert_libero_data_to_lerobot.py
│   │   ├── main.py
│   │   ├── requirements.in
│   │   └── requirements.txt
│   ├── policy_records.ipynb
│   ├── simple_client/
│   │   ├── Dockerfile
│   │   ├── README.md
│   │   ├── compose.yml
│   │   ├── main.py
│   │   ├── requirements.in
│   │   └── requirements.txt
│   └── ur5/
│       └── README.md
├── packages/
│   └── openpi-client/
│       ├── pyproject.toml
│       └── src/
│           └── openpi_client/
│               ├── __init__.py
│               ├── action_chunk_broker.py
│               ├── base_policy.py
│               ├── image_tools.py
│               ├── image_tools_test.py
│               ├── msgpack_numpy.py
│               ├── msgpack_numpy_test.py
│               ├── runtime/
│               │   ├── agent.py
│               │   ├── agents/
│               │   │   └── policy_agent.py
│               │   ├── environment.py
│               │   ├── runtime.py
│               │   └── subscriber.py
│               └── websocket_client_policy.py
├── pyproject.toml
├── scripts/
│   ├── __init__.py
│   ├── compute_norm_stats.py
│   ├── docker/
│   │   ├── compose.yml
│   │   ├── install_docker_ubuntu22.sh
│   │   ├── install_nvidia_container_toolkit.sh
│   │   └── serve_policy.Dockerfile
│   ├── serve_policy.py
│   ├── train.py
│   ├── train_pytorch.py
│   └── train_test.py
└── src/
    └── openpi/
        ├── __init__.py
        ├── conftest.py
        ├── models/
        │   ├── __init__.py
        │   ├── gemma.py
        │   ├── gemma_fast.py
        │   ├── lora.py
        │   ├── lora_test.py
        │   ├── model.py
        │   ├── model_test.py
        │   ├── pi0.py
        │   ├── pi0_config.py
        │   ├── pi0_fast.py
        │   ├── pi0_test.py
        │   ├── siglip.py
        │   ├── tokenizer.py
        │   ├── tokenizer_test.py
        │   ├── utils/
        │   │   └── fsq_tokenizer.py
        │   └── vit.py
        ├── models_pytorch/
        │   ├── gemma_pytorch.py
        │   ├── pi0_pytorch.py
        │   ├── preprocessing_pytorch.py
        │   └── transformers_replace/
        │       └── models/
        │           ├── gemma/
        │           │   ├── configuration_gemma.py
        │           │   └── modeling_gemma.py
        │           ├── paligemma/
        │           │   └── modeling_paligemma.py
        │           └── siglip/
        │               ├── check.py
        │               └── modeling_siglip.py
        ├── policies/
        │   ├── aloha_policy.py
        │   ├── droid_policy.py
        │   ├── libero_policy.py
        │   ├── policy.py
        │   ├── policy_config.py
        │   └── policy_test.py
        ├── py.typed
        ├── serving/
        │   └── websocket_policy_server.py
        ├── shared/
        │   ├── __init__.py
        │   ├── array_typing.py
        │   ├── download.py
        │   ├── download_test.py
        │   ├── image_tools.py
        │   ├── image_tools_test.py
        │   ├── nnx_utils.py
        │   ├── normalize.py
        │   └── normalize_test.py
        ├── training/
        │   ├── checkpoints.py
        │   ├── config.py
        │   ├── data_loader.py
        │   ├── data_loader_test.py
        │   ├── droid_rlds_dataset.py
        │   ├── misc/
        │   │   ├── polaris_config.py
        │   │   └── roboarena_config.py
        │   ├── optimizer.py
        │   ├── sharding.py
        │   ├── utils.py
        │   └── weight_loaders.py
        ├── transforms.py
        └── transforms_test.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .dockerignore
================================================
.venv
checkpoints
data


================================================
FILE: .github/CODEOWNERS
================================================
# The CODEOWNERS file defines individuals or teams that are automatically requested for
# review when someone opens a pull request that modifies certain code. When a draft pull
# request is marked as ready for review, code owners are automatically notified.
#
# See: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
#
# This is a comment.
# Each line is a file pattern followed by one or more owners.

# Global owners.
* @jimmyt857 @Michael-Equi @kvablack

src/openpi/models/ @kvablack
src/openpi/training/ @kvablack

scripts/ @jimmyt857 @kvablack

================================================
FILE: .github/workflows/pre-commit.yml
================================================
name: pre-commit
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - "*"
jobs:
  pre-commit:
    runs-on: ubuntu-latest
    env:
      GIT_LFS_SKIP_SMUDGE: true
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v3
      - uses: pre-commit/action@v3.0.1


================================================
FILE: .github/workflows/test.yml
================================================
name: Test
on:
  pull_request:
    branches:
      - "*"

jobs:
  run_tests:
    name: Run Tests
    runs-on: openpi-verylarge
    env:
      GIT_LFS_SKIP_SMUDGE: true
    steps:
      - uses: actions/checkout@v4

      - name: Install FFmpeg dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev

      - name: Install uv
        uses: astral-sh/setup-uv@v5

      - name: Set up Python
        run: uv python install

      - name: Install the project
        run: uv sync --all-extras --dev

      - name: Run tests
        run: uv run pytest --strict-markers -m "not manual"


================================================
FILE: .gitignore
================================================
# Data directories.
assets/
checkpoints/
data/
wandb/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/


================================================
FILE: .gitmodules
================================================
[submodule "third_party/aloha"]
	path = third_party/aloha
	url = https://github.com/Physical-Intelligence/aloha.git
[submodule "third_party/libero"]
	path = third_party/libero
	url = https://github.com/Lifelong-Robot-Learning/LIBERO.git


================================================
FILE: .pre-commit-config.yaml
================================================
exclude: third_party/

repos:
  - repo: https://github.com/astral-sh/uv-pre-commit
    # uv version.
    rev: 0.5.14
    hooks:
      - id: uv-lock
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.8.6
    hooks:
      # Run the linter.
      - id: ruff
        args: [--fix]
      - id: ruff-format

================================================
FILE: .python-version
================================================
3.11

================================================
FILE: .vscode/settings.json
================================================
{
    "[python]": {
        "editor.defaultFormatter": "charliermarsh.ruff",
        "editor.formatOnSave": true,
    },
    "python.testing.pytestArgs": [
        "src"
    ],
    "python.testing.unittestEnabled": false,
    "python.testing.pytestEnabled": true
}

================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to openpi

We welcome contributions, improvements, and modifications. Everyone is welcome to use openpi in accordance with the [license](LICENSE). Contributors are also welcome to submit bug reports, feature requests, and pull requests. We can't promise to approve every pull request, and as a small team we have limited bandwidth to review all requests, but we'll give it our best effort. Specifics are described below.

## Issues and feature requests

You are welcome to use the GitHub [discussion](https://github.com/Physical-Intelligence/openpi/discussions) feature if you would like to discuss something that is neither a bug report nor a feature request. This is suitable for questions about how to use some aspect of openpi, or other topics.

If you find a bug or other issue, please first check that it has not already been reported (use the search bar under Issues on GitHub). If it has not yet been reported, please include the following information when filing a GitHub issue:

- Your OS type and version, and the version of Python you are using
- Code that allows us to reproduce your bug, including all dependencies
- Traceback of any exception
- Any other information that would help us, such as a screenshot

To address an issue, we must be able to reproduce it. If you encountered the issue after making modifications to openpi, please reproduce it without those modifications and provide a code snippet that allows us to quickly reproduce the problem on `main`.

If you would like to submit a feature request, please first check that it has not already been requested, and provide the following information:

- The motivation for the feature
- A description of the problem you are trying to solve or your use case
- Enough information for us to understand the nature of the request
- Some information about how you intend to use it (this might help us in understanding the motivation!)

We can't promise to support every feature request, but it is helpful to us to know the use cases that you are interested in!

## Submitting a pull request

If you have implemented support for a new robot or environment, or some other new feature, we welcome pull requests (PRs) to openpi. If you would like a sense of whether we are likely to approve your PR, we encourage you to create a [feature request](https://github.com/Physical-Intelligence/openpi/issues) or post on the [discussion](https://github.com/Physical-Intelligence/openpi/discussions) board before you start working on it. Since we are a small team with limited ability to provide maintenance and support, we may not accept every PR (e.g., if we believe it would make the code harder to maintain, or if reviewing the PR is out of scope for us), so contacting us in advance is a good way to find out whether your PR is likely to be merged into openpi directly. Even if it isn't, you are of course welcome to maintain your own fork with whatever modifications you like. When creating a PR, please consider the following:

- Make sure that your PR has a clear title and description
- Run `pre-commit` (install using `pre-commit install` first), and run `ruff check .` and `ruff format .`
- Make sure your PR passes all tests
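If it helps, the checklist above can be bundled into a small local script together with the test command from the "Run tests" step of `.github/workflows/test.yml`. The sketch below is a hypothetical helper (`run_checks` and `CHECKS` are illustrative names, not part of openpi). It assumes `pre-commit`, `ruff`, and `uv` are installed and on your `PATH`, and it uses `pre-commit run --all-files` so the hooks run over the whole repository rather than only over staged files.

```python
import shlex
import subprocess

# Hypothetical pre-PR checklist (not part of openpi). Assumes pre-commit, ruff,
# and uv are installed and on PATH. The pytest invocation mirrors the "Run tests"
# step in .github/workflows/test.yml, which deselects tests marked `manual`.
CHECKS: list[list[str]] = [
    ["pre-commit", "run", "--all-files"],
    ["ruff", "check", "."],
    ["ruff", "format", "."],
    ["uv", "run", "pytest", "--strict-markers", "-m", "not manual"],
]


def run_checks(dry_run: bool = False) -> list[str]:
    """Run each check in order and stop at the first failure.

    Returns the commands as shell-quoted strings. With dry_run=True the
    commands are only listed, not executed.
    """
    commands = []
    for cmd in CHECKS:
        commands.append(shlex.join(cmd))
        if not dry_run:
            # check=True raises subprocess.CalledProcessError on a non-zero exit code.
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    # Print the plan; call run_checks() without dry_run to actually execute it.
    for line in run_checks(dry_run=True):
        print(line)
```

Running the checks sequentially with `check=True` means the first failing step stops the script, so you fix lint or formatting problems before spending time on the full test suite.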


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

================================================
FILE: LICENSE_GEMMA.txt
================================================
Gemma Terms of Use 

Last modified: February 21, 2024

By using, reproducing, modifying, distributing, performing or displaying any portion or element of Gemma, Model Derivatives including via any Hosted Service, (each as defined below) (collectively, the "Gemma Services") or otherwise accepting the terms of this Agreement, you agree to be bound by this Agreement.

Section 1: DEFINITIONS
1.1 Definitions
(a) "Agreement" or "Gemma Terms of Use" means these terms and conditions that govern the use, reproduction, Distribution or modification of the Gemma Services and any terms and conditions incorporated by reference.

(b) "Distribution" or "Distribute" means any transmission, publication, or other sharing of Gemma or Model Derivatives to a third party, including by providing or making Gemma or its functionality available as a hosted service via API, web access, or any other electronic or remote means ("Hosted Service").

(c) "Gemma" means the set of machine learning language models, trained model weights and parameters identified at ai.google.dev/gemma, regardless of the source that you obtained it from.

(d) "Google" means Google LLC.

(e) "Model Derivatives" means all (i) modifications to Gemma, (ii) works based on Gemma, or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Gemma, to that model in order to cause that model to perform similarly to Gemma, including distillation methods that use intermediate data representations or methods based on the generation of synthetic data Outputs by Gemma for training that model. For clarity, Outputs are not deemed Model Derivatives.

(f) "Output" means the information content output of Gemma or a Model Derivative that results from operating or otherwise using Gemma or the Model Derivative, including via a Hosted Service.

1.2
As used in this Agreement, "including" means "including without limitation".

Section 2: ELIGIBILITY AND USAGE
2.1 Eligibility
You represent and warrant that you have the legal capacity to enter into this Agreement (including being of sufficient age of consent). If you are accessing or using any of the Gemma Services for or on behalf of a legal entity, (a) you are entering into this Agreement on behalf of yourself and that legal entity, (b) you represent and warrant that you have the authority to act on behalf of and bind that entity to this Agreement and (c) references to "you" or "your" in the remainder of this Agreement refers to both you (as an individual) and that entity.

2.2 Use
You may use, reproduce, modify, Distribute, perform or display any of the Gemma Services only in accordance with the terms of this Agreement, and must not violate (or encourage or permit anyone else to violate) any term of this Agreement.

Section 3: DISTRIBUTION AND RESTRICTIONS
3.1 Distribution and Redistribution
You may reproduce or Distribute copies of Gemma or Model Derivatives if you meet all of the following conditions:

You must include the use restrictions referenced in Section 3.2 as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Gemma or Model Derivatives and you must provide notice to subsequent users you Distribute to that Gemma or Model Derivatives are subject to the use restrictions in Section 3.2.
You must provide all third party recipients of Gemma or Model Derivatives a copy of this Agreement.
You must cause any modified files to carry prominent notices stating that you modified the files.
All Distributions (other than through a Hosted Service) must be accompanied by a "Notice" text file that contains the following notice: "Gemma is provided under and subject to the Gemma Terms of Use found at ai.google.dev/gemma/terms".
You may add your own intellectual property statement to your modifications and, except as set forth in this Section, may provide additional or different terms and conditions for use, reproduction, or Distribution of your modifications, or for any such Model Derivatives as a whole, provided your use, reproduction, modification, Distribution, performance, and display of Gemma otherwise complies with the terms and conditions of this Agreement. Any additional or different terms and conditions you impose must not conflict with the terms of this Agreement.

3.2 Use Restrictions
You must not use any of the Gemma Services:

for the restricted uses set forth in the Gemma Prohibited Use Policy at ai.google.dev/gemma/prohibited_use_policy ("Prohibited Use Policy"), which is hereby incorporated by reference into this Agreement; or
in violation of applicable laws and regulations.
To the maximum extent permitted by law, Google reserves the right to restrict (remotely or otherwise) usage of any of the Gemma Services that Google reasonably believes are in violation of this Agreement.

3.3 Generated Output
Google claims no rights in Outputs you generate using Gemma. You and your users are solely responsible for Outputs and their subsequent uses.

Section 4: ADDITIONAL PROVISIONS
4.1 Updates
Google may update Gemma from time to time, and you must make reasonable efforts to use the latest version of Gemma.

4.2 Trademarks
Nothing in this Agreement grants you any rights to use Google's trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between you and Google. Google reserves any rights not expressly granted herein.

4.3 DISCLAIMER OF WARRANTY
UNLESS REQUIRED BY APPLICABLE LAW, THE GEMMA SERVICES, AND OUTPUTS, ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR OR DISTRIBUTING ANY OF THE GEMMA SERVICES OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR USE OR DISTRIBUTION OF ANY OF THE GEMMA SERVICES OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.

4.4 LIMITATION OF LIABILITY
TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), PRODUCT LIABILITY, CONTRACT, OR OTHERWISE, UNLESS REQUIRED BY APPLICABLE LAW, SHALL GOOGLE OR ITS AFFILIATES BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL, OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO, ANY OF THE GEMMA SERVICES OR OUTPUTS EVEN IF GOOGLE OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.

4.5 Term, Termination, and Survival
The term of this Agreement will commence upon your acceptance of this Agreement (including acceptance by your use, modification, or Distribution, reproduction, performance or display of any portion or element of the Gemma Services) and will continue in full force and effect until terminated in accordance with the terms of this Agreement. Google may terminate this Agreement if you are in breach of any term of this Agreement. Upon termination of this Agreement, you must delete and cease use and Distribution of all copies of Gemma and Model Derivatives in your possession or control. Sections 1, 2.1, 3.3, 4.2 to 4.9 shall survive the termination of this Agreement.

4.6 Governing Law and Jurisdiction
This Agreement will be governed by the laws of the State of California without regard to choice of law principles. The UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The state and federal courts of Santa Clara County, California shall have exclusive jurisdiction of any dispute arising out of this Agreement.

4.7 Severability
If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.

4.8 Entire Agreement
This Agreement states all the terms agreed between the parties and supersedes all other agreements between the parties as of the date of acceptance relating to its subject matter.

4.9 No Waiver
Google will not be treated as having waived any rights by not exercising (or delaying the exercise of) any rights under this Agreement.

================================================
FILE: README.md
================================================
# openpi

openpi holds open-source models and packages for robotics, published by the [Physical Intelligence team](https://www.physicalintelligence.company/).

Currently, this repo contains three types of models:
- the [π₀ model](https://www.physicalintelligence.company/blog/pi0), a flow-based vision-language-action model (VLA).
- the [π₀-FAST model](https://www.physicalintelligence.company/research/fast), an autoregressive VLA, based on the FAST action tokenizer.
- the [π₀.₅ model](https://www.physicalintelligence.company/blog/pi05), an upgraded version of π₀ with better open-world generalization, trained with [knowledge insulation](https://www.physicalintelligence.company/research/knowledge_insulation). Note that this repository currently supports only the flow matching head for both $\pi_{0.5}$ training and inference.

For all models, we provide _base model_ checkpoints, pre-trained on 10k+ hours of robot data, and examples for using them out of the box or fine-tuning them to your own datasets.

This is an experiment: $\pi_0$ was developed for our own robots, which differ from the widely used platforms such as [ALOHA](https://tonyzhaozh.github.io/aloha/) and [DROID](https://droid-dataset.github.io/), and though we are optimistic that researchers and practitioners will be able to run creative new experiments adapting $\pi_0$ to their own platforms, we do not expect every such attempt to be successful. All this is to say: $\pi_0$ may or may not work for you, but you are welcome to try it and see!

## Updates

- [Sept 2025] We released PyTorch support in openpi.
- [Sept 2025] We released pi05, an upgraded version of pi0 with better open-world generalization.
- [Sept 2025] We have added an [improved idle filter](examples/droid/README_train.md#data-filtering) for DROID training.
- [Jun 2025] We have added [instructions](examples/droid/README_train.md) for using `openpi` to train VLAs on the full [DROID dataset](https://droid-dataset.github.io/). This is an approximate open-source implementation of the training pipeline used to train pi0-FAST-DROID.


## Requirements

To run the models in this repository, you will need an NVIDIA GPU with at least the following specifications. These estimates assume a single GPU, but you can also use multiple GPUs to reduce per-GPU memory requirements by sharding the model state with FSDP, configured via `fsdp_devices` in the training config. Please also note that the current JAX training script does not yet support multi-node training.

| Mode               | Memory Required | Example GPU        |
| ------------------ | --------------- | ------------------ |
| Inference          | > 8 GB          | RTX 4090           |
| Fine-Tuning (LoRA) | > 22.5 GB       | RTX 4090           |
| Fine-Tuning (Full) | > 70 GB         | A100 (80GB) / H100 |

The repo has been tested with Ubuntu 22.04; we do not currently support other operating systems.

## Installation

When cloning this repo, make sure to update submodules:

```bash
git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git

# Or if you already cloned the repo:
git submodule update --init --recursive
```

We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up. Once uv is installed, run the following to set up the environment:

```bash
GIT_LFS_SKIP_SMUDGE=1 uv sync
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
```

NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency.

**Docker**: As an alternative to uv installation, we provide instructions for installing openpi using Docker. If you encounter issues with your system setup, consider using Docker to simplify installation. See [Docker Setup](docs/docker.md) for more details.




## Model Checkpoints

### Base Models
We provide multiple base VLA model checkpoints. These checkpoints have been pre-trained on 10k+ hours of robot data, and can be used for fine-tuning.

| Model        | Use Case    | Description                                                                                                 | Checkpoint Path                                |
| ------------ | ----------- | ----------------------------------------------------------------------------------------------------------- | ---------------------------------------------- |
| $\pi_0$      | Fine-Tuning | Base [π₀ model](https://www.physicalintelligence.company/blog/pi0) for fine-tuning                | `gs://openpi-assets/checkpoints/pi0_base`      |
| $\pi_0$-FAST | Fine-Tuning | Base autoregressive [π₀-FAST model](https://www.physicalintelligence.company/research/fast) for fine-tuning | `gs://openpi-assets/checkpoints/pi0_fast_base` |
| $\pi_{0.5}$    | Fine-Tuning | Base [π₀.₅ model](https://www.physicalintelligence.company/blog/pi05) for fine-tuning    | `gs://openpi-assets/checkpoints/pi05_base`      |

### Fine-Tuned Models
We also provide "expert" checkpoints for various robot platforms and tasks. These models are fine-tuned from the base models above and intended to run directly on the target robot. Because they were fine-tuned on relatively small datasets collected with widely available robots, such as ALOHA and the DROID Franka setup, they may not generalize to your particular robot, although in practice we found some of them, especially the DROID checkpoint, to generalize quite broadly.

| Model                    | Use Case    | Description                                                                                                                                                                                              | Checkpoint Path                                       |
| ------------------------ | ----------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------- |
| $\pi_0$-FAST-DROID       | Inference   | $\pi_0$-FAST model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/): can perform a wide range of simple table-top manipulation tasks 0-shot in new scenes on the DROID robot platform | `gs://openpi-assets/checkpoints/pi0_fast_droid`       |
| $\pi_0$-DROID            | Fine-Tuning | $\pi_0$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/): faster inference than $\pi_0$-FAST-DROID, but may not follow language commands as well                                | `gs://openpi-assets/checkpoints/pi0_droid`            |
| $\pi_0$-ALOHA-towel      | Inference   | $\pi_0$ model fine-tuned on internal [ALOHA](https://tonyzhaozh.github.io/aloha/) data: can fold diverse towels 0-shot on ALOHA robot platforms                                                          | `gs://openpi-assets/checkpoints/pi0_aloha_towel`      |
| $\pi_0$-ALOHA-tupperware | Inference   | $\pi_0$ model fine-tuned on internal [ALOHA](https://tonyzhaozh.github.io/aloha/) data: can unpack food from a tupperware container                                                                                                             | `gs://openpi-assets/checkpoints/pi0_aloha_tupperware` |
| $\pi_0$-ALOHA-pen-uncap  | Inference   | $\pi_0$ model fine-tuned on public [ALOHA](https://dit-policy.github.io/) data: can uncap a pen                                                                                                          | `gs://openpi-assets/checkpoints/pi0_aloha_pen_uncap`  |
| $\pi_{0.5}$-LIBERO      | Inference   | $\pi_{0.5}$ model fine-tuned for the [LIBERO](https://libero-project.github.io/datasets) benchmark: gets state-of-the-art performance (see [LIBERO README](examples/libero/README.md)) | `gs://openpi-assets/checkpoints/pi05_libero`      |
| $\pi_{0.5}$-DROID      | Inference / Fine-Tuning | $\pi_{0.5}$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/) with [knowledge insulation](https://www.physicalintelligence.company/research/knowledge_insulation): fast inference and good language-following | `gs://openpi-assets/checkpoints/pi05_droid`      |


By default, checkpoints are downloaded automatically from `gs://openpi-assets` when needed and cached in `~/.cache/openpi`. You can change the cache location by setting the `OPENPI_DATA_HOME` environment variable.
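
As an illustration of the caching behavior described above, here is a minimal sketch of how a cache root that honors `OPENPI_DATA_HOME` could be resolved. The helpers `resolve_cache_dir` and `local_path_for` are hypothetical, and mirroring the bucket path under the cache root is an assumption, not necessarily the layout openpi's downloader uses.

```python
import os
from pathlib import Path

DEFAULT_CACHE = "~/.cache/openpi"  # default location stated in the README


def resolve_cache_dir() -> Path:
    """Return the cache root, preferring OPENPI_DATA_HOME when it is set."""
    override = os.environ.get("OPENPI_DATA_HOME")
    return Path(override or DEFAULT_CACHE).expanduser()


def local_path_for(url: str) -> Path:
    """Hypothetical: map a gs:// URL to a path that mirrors it under the cache root."""
    prefix = "gs://"
    if not url.startswith(prefix):
        raise ValueError(f"expected a gs:// URL, got {url!r}")
    return resolve_cache_dir() / url[len(prefix):]


# Example: redirect the cache to a larger disk before any downloads happen.
os.environ["OPENPI_DATA_HOME"] = "/data/openpi"
print(local_path_for("gs://openpi-assets/checkpoints/pi05_droid"))
# -> /data/openpi/openpi-assets/checkpoints/pi05_droid
```

In practice you would usually export `OPENPI_DATA_HOME` in your shell rather than setting it from Python; the effect on the resolved path is the same.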




## Running Inference for a Pre-Trained Model

Our pre-trained model checkpoints can be run with a few lines of code (here our $\pi_{0.5}$-DROID model):
```python
from openpi.training import config as _config
from openpi.policies import policy_config
from openpi.shared import download

config = _config.get_config("pi05_droid")
checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi05_droid")

# Create a trained policy.
policy = policy_config.create_trained_policy(config, checkpoint_dir)

# Run inference on a dummy example.
example = {
    "observation/exterior_image_1_left": ...,
    "observation/wrist_image_left": ...,
    # ... (remaining observation entries)
    "prompt": "pick up the fork"
}
action_chunk = policy.infer(example)["actions"]
```
You can also test this out in the [example notebook](examples/inference.ipynb).

We provide detailed step-by-step examples for running inference of our pre-trained checkpoints on [DROID](examples/droid/README.md) and [ALOHA](examples/aloha_real/README.md) robots.

**Remote Inference**: We provide [examples and code](docs/remote_inference.md) for running inference of our models **remotely**: the model can run on a different server and stream actions to the robot via a websocket connection. This makes it easy to use more powerful GPUs off-robot and keep robot and policy environments separate.

**Test inference without a robot**: We provide a [script](examples/simple_client/README.md) for testing inference without a robot. This script will generate a random observation and run inference with the model. See [here](examples/simple_client/README.md) for more details.





## Fine-Tuning Base Models on Your Own Data

We will fine-tune the $\pi_{0.5}$ model on the [LIBERO dataset](https://libero-project.github.io/datasets) as a running example of how to fine-tune a base model on your own data. We will explain three steps:
1. Converting your data to a LeRobot dataset (the format we use for training)
2. Defining training configs and running training
3. Spinning up a policy server and running inference

### 1. Convert your data to a LeRobot dataset

We provide a minimal example script for converting LIBERO data to a LeRobot dataset in [`examples/libero/convert_libero_data_to_lerobot.py`](examples/libero/convert_libero_data_to_lerobot.py). You can easily modify it to convert your own data! You can download the raw LIBERO dataset from [here](https://huggingface.co/datasets/openvla/modified_libero_rlds), and run the script with:

```bash
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/libero/data
```

**Note:** If you just want to fine-tune on LIBERO, you can skip this step, because our LIBERO fine-tuning configs point to a pre-converted LIBERO dataset. This step is merely an example that you can adapt to your own data.

### 2. Defining training configs and running training

To fine-tune a base model on your own data, you need to define configs for data processing and training. We provide example configs with detailed comments for LIBERO below, which you can modify for your own dataset:

- [`LiberoInputs` and `LiberoOutputs`](src/openpi/policies/libero_policy.py): Defines the data mapping from the LIBERO environment to the model and vice versa. Used for both training and inference.
- [`LeRobotLiberoDataConfig`](src/openpi/training/config.py): Defines how to process raw LIBERO data from LeRobot dataset for training.
- [`TrainConfig`](src/openpi/training/config.py): Defines fine-tuning hyperparameters, data config, and weight loader.

We provide example fine-tuning configs for [π₀](src/openpi/training/config.py), [π₀-FAST](src/openpi/training/config.py), and [π₀.₅](src/openpi/training/config.py) on LIBERO data.

Before we can run training, we need to compute the normalization statistics for the training data. Run the script below with the name of your training config:

```bash
uv run scripts/compute_norm_stats.py --config-name pi05_libero
```

Now we can kick off training with the following command (the `--overwrite` flag is used to overwrite existing checkpoints if you rerun fine-tuning with the same config):

```bash
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_libero --exp-name=my_experiment --overwrite
```

The command logs training progress to the console and saves checkpoints to the `checkpoints` directory. You can also monitor training progress on the Weights & Biases dashboard. Setting `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` before running training, as in the command above, lets JAX use up to 90% of the GPU memory (vs. the default of 75%), which maximizes the memory available for training.
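
If you launch training from your own Python wrapper rather than the shell (an assumption; the README sets the variable on the command line), the memory fraction can be set in code. The variable has to be in the environment before JAX is imported to be sure it takes effect. `preallocated_gb` is a hypothetical helper that only does the arithmetic: 0.9 of an 80 GB A100 is 72 GB, versus 60 GB at the 0.75 default.

```python
import os

# Set before `import jax` so the setting is guaranteed to take effect.
# setdefault keeps any value already exported in the shell.
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.9")

JAX_DEFAULT_FRACTION = 0.75  # JAX preallocates 75% of GPU memory by default


def preallocated_gb(total_gpu_gb: float, fraction: float = JAX_DEFAULT_FRACTION) -> float:
    """Hypothetical helper: GPU memory (GB) JAX reserves up front for a given fraction."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    return total_gpu_gb * fraction


fraction = float(os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"])
print(preallocated_gb(80, fraction), "GB vs", preallocated_gb(80), "GB by default")
```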

**Note:** We provide functionality for *reloading* normalization statistics for state / action normalization from pre-training. This can be beneficial if you are fine-tuning to a new task on a robot that was part of our pre-training mixture. For more details on how to reload normalization statistics, see the [norm_stats.md](docs/norm_stats.md) file.

### 3. Spinning up a policy server and running inference

Once training is complete, we can run inference by spinning up a policy server and querying it from a LIBERO evaluation script. Launching a model server takes a single command (this example uses the checkpoint from iteration 20,000; adjust the path as needed):

```bash
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_libero --policy.dir=checkpoints/pi05_libero/my_experiment/20000
```

This will spin up a server that listens on port 8000 and waits for observations to be sent to it. We can then run an evaluation script (or robot runtime) that queries the server.
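
Before pointing a client at the server, a quick TCP check can confirm that something is listening on the port. This is a generic sketch using only the Python standard library; it does not perform the websocket handshake that the openpi client uses, so a successful connection only shows that the port is open, not that the policy is ready.

```python
import socket


def server_is_listening(host: str = "localhost", port: int = 8000, timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Refused, timed out (socket.timeout is an OSError subclass), or unreachable.
        return False


if __name__ == "__main__":
    print("policy server reachable:", server_is_listening("localhost", 8000))
```

Replace `localhost` with the server's address when the policy runs on a different machine from the robot.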

For running the LIBERO eval in particular, we provide (and recommend using) a Dockerized workflow that handles both the policy server and the evaluation script together. See the [LIBERO README](examples/libero/README.md) for more details.

If you want to embed a policy server call in your own robot runtime, we have a minimal example of how to do so in the [remote inference docs](docs/remote_inference.md).
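
The serve command above hard-codes the step directory `20000`. Assuming checkpoints are saved as integer-named step directories under `checkpoints/<config>/<exp_name>/`, as that example path suggests, a small hypothetical helper like `latest_checkpoint_dir` can pick the most recent one to pass as `--policy.dir`.

```python
from __future__ import annotations

from pathlib import Path


def latest_checkpoint_dir(exp_dir: str | Path) -> Path:
    """Return the step subdirectory with the highest iteration number."""
    steps = [p for p in Path(exp_dir).iterdir() if p.is_dir() and p.name.isdecimal()]
    if not steps:
        raise FileNotFoundError(f"no numeric step directories in {exp_dir}")
    # Compare numerically: as strings, "5000" would sort after "20000".
    return max(steps, key=lambda p: int(p.name))


# Usage (hypothetical experiment directory):
# latest_checkpoint_dir("checkpoints/pi05_libero/my_experiment")
```

Non-numeric entries in the experiment directory are ignored, so extra files or folders saved next to the step directories do not interfere.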



### More Examples

We provide more examples of how to fine-tune and run inference with our models on the ALOHA and UR5 platforms in the following READMEs:
- [ALOHA Simulator](examples/aloha_sim)
- [ALOHA Real](examples/aloha_real)
- [UR5](examples/ur5)

## PyTorch Support

openpi now provides PyTorch implementations of π₀ and π₀.₅ models alongside the original JAX versions! The PyTorch implementation has been validated on the LIBERO benchmark (both inference and finetuning). A few features are currently not supported (this may change in the future):

- The π₀-FAST model
- Mixed precision training
- FSDP (fully-sharded data parallelism) training
- LoRA (low-rank adaptation) training
- EMA (exponential moving average) weights during training

### Setup
1. Make sure that you have the latest version of all dependencies installed: `uv sync`

2. Double-check that you have transformers 4.53.2 installed: `uv pip show transformers`

3. Apply the transformers library patches:
   ```bash
   cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/
   ```

This overwrites several files in the transformers library with necessary model changes: 1) supporting AdaRMS, 2) correctly controlling the precision of activations, and 3) allowing the KV cache to be used without being updated.

**WARNING**: With the default uv link mode (hardlink), this will permanently affect the transformers library in your uv cache, meaning the changes will survive reinstallations of transformers and could even propagate to other projects that use transformers. To fully undo this operation, you must run `uv cache clean transformers`.
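
The `cp` command above assumes the virtual environment uses Python 3.11. If you are unsure where transformers is installed, a small sketch like the following, run with `uv run python` so that it uses the project environment, returns the package directory to copy the patched files into. `package_dir` is a hypothetical helper built on the standard `importlib.util.find_spec`.

```python
import importlib.util
from pathlib import Path


def package_dir(name: str) -> Path:
    """Return the directory an installed package would be imported from."""
    spec = importlib.util.find_spec(name)
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError(f"package {name!r} not found")
    # For a regular package, `origin` is the path of its __init__.py.
    return Path(spec.origin).parent


# Usage inside the project environment:
# print(package_dir("transformers"))
```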

### Converting JAX Models to PyTorch

To convert a JAX model checkpoint to PyTorch format:

```bash
uv run examples/convert_jax_model_to_pytorch.py \
    --checkpoint_dir /path/to/jax/checkpoint \
    --config_name <config name> \
    --output_path /path/to/converted/pytorch/checkpoint
```

### Running Inference with PyTorch

The PyTorch implementation uses the same API as the JAX version; you only need to change the checkpoint path to point to the converted PyTorch model:

```python
from openpi.training import config as _config
from openpi.policies import policy_config

config = _config.get_config("pi05_droid")
checkpoint_dir = "/path/to/converted/pytorch/checkpoint"

# Create a trained policy (automatically detects PyTorch format)
policy = policy_config.create_trained_policy(config, checkpoint_dir)

# Run inference (same API as JAX); `example` is an observation dict as in the JAX example above
action_chunk = policy.infer(example)["actions"]
```

### Policy Server with PyTorch

The policy server works the same way with PyTorch models; just point it to the converted checkpoint directory:

```bash
uv run scripts/serve_policy.py policy:checkpoint \
    --policy.config=pi05_droid \
    --policy.dir=/path/to/converted/pytorch/checkpoint
```

### Finetuning with PyTorch

To finetune a model in PyTorch:

1. Convert the JAX base model to PyTorch format:
   ```bash
   uv run examples/convert_jax_model_to_pytorch.py \
       --config_name <config name> \
       --checkpoint_dir /path/to/jax/base/model \
       --output_path /path/to/pytorch/base/model
   ```

2. Specify the converted PyTorch model path in your config using `pytorch_weight_path`

3. Launch training using one of these modes:

```bash
# Single GPU training:
uv run scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>

# Example:
uv run scripts/train_pytorch.py debug --exp_name pytorch_test
uv run scripts/train_pytorch.py debug --exp_name pytorch_test --resume  # Resume from latest checkpoint

# Multi-GPU training (single node):
uv run torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>

# Example:
uv run torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
uv run torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume

# Multi-Node Training:
uv run torchrun \
    --nnodes=<num_nodes> \
    --nproc_per_node=<gpus_per_node> \
    --node_rank=<rank_of_node> \
    --master_addr=<master_ip> \
    --master_port=<port> \
    scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
```
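
torchrun launches one process per GPU and exports `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` to each of them. The sketch below shows the generic pattern for reading these variables, falling back to single-process defaults so the same script also runs without torchrun. It illustrates standard torchrun behavior and is not taken from `scripts/train_pytorch.py`; `DistInfo` and `read_torchrun_env` are hypothetical names.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class DistInfo:
    rank: int        # global rank across all nodes
    local_rank: int  # GPU index on this node
    world_size: int  # total number of processes

    @property
    def is_main_process(self) -> bool:
        # By common convention, only rank 0 logs and writes checkpoints.
        return self.rank == 0


def read_torchrun_env() -> DistInfo:
    """Read torchrun's variables, defaulting to a single-process run."""
    return DistInfo(
        rank=int(os.environ.get("RANK", "0")),
        local_rank=int(os.environ.get("LOCAL_RANK", "0")),
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
    )
```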

### Precision Settings

JAX and PyTorch implementations handle precision as follows:

**JAX:**
1. Inference: most weights and computations in bfloat16, with a few computations in float32 for stability
2. Training: defaults to mixed precision: weights and gradients in float32, (most) activations and computations in bfloat16. You can change to full float32 training by setting `dtype` to float32 in the config.

**PyTorch:**
1. Inference: matches JAX -- most weights and computations in bfloat16, with a few weights converted to float32 for stability
2. Training: supports either full bfloat16 (default) or full float32. You can change it by setting `pytorch_training_precision` in the config. bfloat16 uses less memory but exhibits higher losses compared to float32. Mixed precision is not yet supported.

With torch.compile, inference speed is comparable between JAX and PyTorch.
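
One plausible contributor to the higher losses of full bfloat16 training is bfloat16's limited precision: it keeps float32's 8-bit exponent but only 7 explicit mantissa bits, so values near 1.0 are spaced 2^-7 (about 0.0078) apart and smaller updates can round away entirely. The pure-Python sketch below emulates float32-to-bfloat16 rounding (round to nearest, ties to even) by dropping the low 16 bits of the float32 representation; it is an illustration, not how JAX or PyTorch implement the cast.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a Python float.

    bfloat16 keeps float32's sign and 8-bit exponent but only the top 7 mantissa bits.
    This sketch does not special-case NaN and raises OverflowError beyond the float32 range.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Adding this bias before truncating the low 16 bits rounds to nearest, ties to even.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


weight, update = 1.0, 1e-3  # a small optimizer step on a weight of magnitude 1
print(to_bfloat16(weight + update))  # prints 1.0: the update is lost at bfloat16 precision
```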

## Troubleshooting

We will collect common issues and their solutions here. If you encounter an issue, please check here first. If you can't find a solution, please file an issue on the repo (see [here](CONTRIBUTING.md) for guidelines).

| Issue                                     | Resolution                                                                                                                                                                                   |
| ----------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `uv sync` fails with dependency conflicts | Try removing the virtual environment directory (`rm -rf .venv`) and running `uv sync` again. If issues persist, check that you have the latest version of `uv` installed (`uv self update`). |
| Training runs out of GPU memory           | Make sure you set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` (or higher) before running training to allow JAX to use more GPU memory. You can also use `--fsdp-devices <n>` where `<n>` is your number of GPUs, to enable [fully-sharded data parallelism](https://engineering.fb.com/2021/07/15/open-source/fsdp/), which reduces memory usage in exchange for slower training (the amount of slowdown depends on your particular setup). If you are still running out of memory, you may want to consider disabling EMA.        |
| Policy server connection errors           | Check that the server is running and listening on the expected port. Verify network connectivity and firewall settings between client and server.                                            |
| Missing norm stats error when training    | Run `scripts/compute_norm_stats.py` with your config name before starting training.                                                                                                          |
| Dataset download fails                    | Check your internet connection. For HuggingFace datasets, ensure you're logged in (`huggingface-cli login`).                                                                                 |
| CUDA/GPU errors                           | Verify NVIDIA drivers are installed correctly. For Docker, ensure nvidia-container-toolkit is installed. Check GPU compatibility. You do NOT need CUDA libraries installed at the system level; they will be installed via uv. You may even want to try *uninstalling* system CUDA libraries if you run into CUDA issues, since system libraries can sometimes cause conflicts. |
| Import errors when running examples       | Make sure you've installed all dependencies with `uv sync`. Some examples may have additional requirements listed in their READMEs.                    |
| Action dimensions mismatch                | Verify your data processing transforms match the expected input/output dimensions of your robot. Check the action space definitions in your policy classes.                                  |
| Diverging training loss                            | Check the `q01`, `q99`, and `std` values in `norm_stats.json` for your dataset. Certain dimensions that are rarely used can end up with very small `q01`, `q99`, or `std` values, leading to huge states and actions after normalization. You can manually adjust the norm stats as a workaround. |
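
The last row can be illustrated with a small sketch. It assumes quantile normalization of the form `(x - q01) / (q99 - q01 + eps) * 2 - 1`, which maps `[q01, q99]` to `[-1, 1]`; when a rarely used dimension has `q01` equal or nearly equal to `q99`, any nonzero value explodes. `widen_narrow_ranges` is a hypothetical example of the manual adjustment mentioned above, enforcing a minimum range per dimension; the `min_range` threshold is arbitrary and should be chosen per robot.

```python
import numpy as np


def normalize_quantile(x, q01, q99, eps=1e-6):
    """Map [q01, q99] to [-1, 1] per dimension (assumed normalization form)."""
    return (np.asarray(x) - q01) / (q99 - q01 + eps) * 2.0 - 1.0


def widen_narrow_ranges(q01, q99, min_range=0.1):
    """Hypothetical fix: symmetrically widen any dimension whose q99 - q01 is below min_range."""
    q01 = np.asarray(q01, dtype=np.float64).copy()
    q99 = np.asarray(q99, dtype=np.float64).copy()
    center = (q01 + q99) / 2.0
    narrow = (q99 - q01) < min_range
    q01[narrow] = center[narrow] - min_range / 2.0
    q99[narrow] = center[narrow] + min_range / 2.0
    return q01, q99


# Dimension 1 is almost always 0 in the training data, so q01 == q99 == 0.
q01, q99 = np.array([-1.0, 0.0]), np.array([1.0, 0.0])
x = np.array([0.5, 0.1])
print(normalize_quantile(x, q01, q99))                        # dim 1 is about 2e5
print(normalize_quantile(x, *widen_narrow_ranges(q01, q99)))  # dim 1 is about 2
```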


================================================
FILE: docs/docker.md
================================================
### Docker Setup

All of the examples in this repo include instructions for running them both natively and with Docker. Although not required, the Docker option is recommended because it simplifies software installation, produces a more stable environment, and, for examples that depend on ROS, lets you avoid installing ROS and cluttering your machine.

- Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/).
- Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/).
- To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
- The version of docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`.
- Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`.


If you are starting from scratch on an Ubuntu 22.04 host, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.

Build the Docker image and start the container with the following command:
```bash
docker compose -f scripts/docker/compose.yml up --build
```

To build and run the Docker image for a specific example, use the following command:
```bash
docker compose -f examples/<example_name>/compose.yml up --build
```
where `<example_name>` is the name of the example you want to run.

During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.

================================================
FILE: docs/norm_stats.md
================================================
# Normalization statistics

Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint.
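
To make "computed over the training data" concrete, here is a minimal in-memory sketch of per-dimension statistics over a `[num_samples, dim]` array of states or actions. The names `std`, `q01`, and `q99` match the fields mentioned in the main README's troubleshooting section; `mean` is included as the usual companion to `std`. The real `scripts/compute_norm_stats.py` iterates over the dataset and may compute quantiles differently, so `per_dim_stats` (a hypothetical name) is only an illustration.

```python
from __future__ import annotations

import numpy as np


def per_dim_stats(values: np.ndarray) -> dict[str, np.ndarray]:
    """Per-dimension statistics over an array of shape [num_samples, dim]."""
    values = np.asarray(values, dtype=np.float64)
    if values.ndim != 2:
        raise ValueError("expected shape [num_samples, dim]")
    return {
        "mean": values.mean(axis=0),
        "std": values.std(axis=0),
        "q01": np.quantile(values, 0.01, axis=0),
        "q99": np.quantile(values, 0.99, axis=0),
    }


# Hypothetical data: 1,000 recorded 8-dim actions (7 joints + gripper).
rng = np.random.default_rng(0)
actions = rng.normal(size=(1000, 8))
stats = per_dim_stats(actions)
print({k: v.shape for k, v in stats.items()})
```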

## Reloading normalization statistics

When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model.

**If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, like below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint:

```python
TrainConfig(
    ...
    data=LeRobotAlohaDataConfig(
        ...
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="trossen",
        ),
    ),
)
```

For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).

**Note:** To successfully reload normalization statistics, it's important that your robot + dataset are following the action space definitions used in pre-training. We provide a detailed description of our action space definitions below.

**Note #2:** Whether reloading normalization statistics is beneficial depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. We recommend always trying both options, reloading the statistics and training with a fresh set of statistics computed on your new dataset (see the [main README](../README.md) for instructions on how to compute new statistics), and picking the one that works better for your task.


## Provided Pre-training Normalization Statistics

Below is a list of all the pre-training normalization statistics we provide. We provide them for both the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_base/assets` and for `pi0_fast_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_fast_base/assets`.

| Robot | Description | Asset ID |
|-------|-------------|----------|
| ALOHA | 6-DoF dual arm robot with parallel grippers | trossen |
| Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile |
| Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid |
| Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka |
| UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e |
| UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual |
| ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx |
| ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile |
| Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile |
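
To make the lookup concrete, here is a minimal sketch (not part of openpi) that turns a base model name and an asset ID from the table into the `assets_dir`/`asset_id` pair for an `AssetsConfig`. The function name `norm_stats_location` and the validation sets are hypothetical illustrations; only the path pattern and the IDs come from this document.

```python
# Asset IDs from the table above.
PROVIDED_ASSET_IDS = frozenset(
    {
        "trossen",
        "trossen_mobile",
        "droid",
        "franka",
        "ur5e",
        "ur5e_dual",
        "arx",
        "arx_mobile",
        "fibocom_mobile",
    }
)
# Base models for which the statistics are provided.
BASE_MODELS = frozenset({"pi0_base", "pi0_fast_base"})


def norm_stats_location(model: str, asset_id: str) -> tuple[str, str]:
    """Hypothetical helper: return (assets_dir, asset_id) for reloading provided stats."""
    if model not in BASE_MODELS:
        raise ValueError(f"no provided normalization stats for model {model!r}")
    if asset_id not in PROVIDED_ASSET_IDS:
        raise ValueError(f"unknown asset ID {asset_id!r}")
    return f"gs://openpi-assets/checkpoints/{model}/assets", asset_id


assets_dir, asset_id = norm_stats_location("pi0_fast_base", "ur5e")
print(assets_dir, asset_id)  # gs://openpi-assets/checkpoints/pi0_fast_base/assets ur5e
```

The two returned values correspond to the `assets_dir` and `asset_id` fields of the `AssetsConfig` in the training-config example earlier in this document.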


## Pi0 Model Action Space Definitions

Out of the box, both `pi0_base` and `pi0_fast_base` use the following action space definitions (left and right are defined as seen from behind the robot, looking toward the workspace):
```
    "dim_0:dim_5": "left arm joint angles",
    "dim_6": "left arm gripper position",
    "dim_7:dim_12": "right arm joint angles (for bi-manual only)",
    "dim_13": "right arm gripper position (for bi-manual only)",

    # For mobile robots:
    "dim_14:dim_15": "x-y base velocity (for mobile robots only)",
```

The proprioceptive state uses the same definitions as the action space, except that for mobile robots it does not include the x-y base velocity (the last two action dimensions).
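
As an illustration of this layout, here is a hypothetical NumPy sketch (the helper names are ours, not openpi's) that packs per-arm readings into a bi-manual action vector, optionally appending the base velocity for mobile robots, and packs the matching state without the base dimensions. The inclusive `dim_a:dim_b` ranges above become half-open Python slices.

```python
import numpy as np

# Python slices for the action dimensions listed above ("dim_a:dim_b" is inclusive there).
LEFT_ARM = slice(0, 6)  # dim_0:dim_5, left arm joint angles
LEFT_GRIPPER = 6  # dim_6, left arm gripper position
RIGHT_ARM = slice(7, 13)  # dim_7:dim_12, bi-manual only
RIGHT_GRIPPER = 13  # dim_13, bi-manual only
BASE_VELOCITY = slice(14, 16)  # dim_14:dim_15, mobile robots only


def pack_bimanual_action(left_q, left_grip, right_q, right_grip, base_velocity=None):
    """Return a 14-dim action (static robot) or a 16-dim action (mobile robot)."""
    action = np.zeros(14 if base_velocity is None else 16, dtype=np.float32)
    action[LEFT_ARM] = left_q
    action[LEFT_GRIPPER] = left_grip
    action[RIGHT_ARM] = right_q
    action[RIGHT_GRIPPER] = right_grip
    if base_velocity is not None:
        action[BASE_VELOCITY] = base_velocity
    return action


def pack_bimanual_state(left_q, left_grip, right_q, right_grip):
    """Same layout as the action, but never includes the base velocity dimensions."""
    return pack_bimanual_action(left_q, left_grip, right_q, right_grip)


zeros6 = np.zeros(6)
action = pack_bimanual_action(zeros6, 0.0, zeros6, 1.0, base_velocity=[0.1, 0.0])
state = pack_bimanual_state(zeros6, 0.0, zeros6, 1.0)
print(action.shape, state.shape)  # (16,) (14,)
```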

For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.
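
A corresponding sketch for a single 7-DoF arm, again with a hypothetical helper name. Under these definitions the gripper of a 7-DoF arm occupies dimension 7, the same index that holds the first right-arm joint on bi-manual 6-DoF robots, which illustrates why your dataset has to follow the pre-training action space before you reload normalization statistics.

```python
import numpy as np


def pack_7dof_action(joints, gripper):
    """Hypothetical helper: dims 0-6 hold the joint actions, dim 7 the gripper action."""
    joints = np.asarray(joints, dtype=np.float32)
    if joints.shape != (7,):
        raise ValueError(f"expected 7 joint values, got shape {joints.shape}")
    return np.concatenate([joints, np.array([gripper], dtype=np.float32)])


action = pack_7dof_action(np.linspace(-0.3, 0.3, 7), gripper=1.0)
print(action.shape, action[7])  # (8,) 1.0
```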

General info for Pi robots:
- Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details).
- Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed.
- Control frequencies are 20 Hz for UR5e and Franka, and 50 Hz for ARX and Trossen (ALOHA) arms.
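
If your gripper driver reports a raw opening (e.g., a finger distance) rather than a normalized value, a linear rescaling like the sketch below maps it into the [0.0, 1.0] convention above. This assumes the raw reading varies roughly linearly between the open and closed limits; the calibration values in the example are hypothetical placeholders, so read the real limits from your own hardware. The sketch also converts the listed control frequencies (plus DROID's 15 Hz, described below) into control periods.

```python
# Control frequencies from this document, in Hz.
CONTROL_HZ = {"ur5e": 20, "franka": 20, "arx": 50, "trossen": 50, "droid": 15}


def to_pi_gripper(raw: float, open_value: float, closed_value: float) -> float:
    """Map a raw gripper reading to [0.0, 1.0]: 0.0 = fully open, 1.0 = fully closed."""
    frac = (raw - open_value) / (closed_value - open_value)
    # Clip readings that fall slightly outside the calibrated range.
    return min(1.0, max(0.0, frac))


def control_period_s(robot: str) -> float:
    """Seconds per control step for the given robot key."""
    return 1.0 / CONTROL_HZ[robot]


# Hypothetical driver that reports 0.08 when fully open and 0.0 when fully closed.
print(to_pi_gripper(0.08, open_value=0.08, closed_value=0.0))  # 0.0
print(to_pi_gripper(0.0, open_value=0.08, closed_value=0.0))  # 1.0
print(control_period_s("trossen"))  # 0.02
```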

For DROID, we use the original DROID action configuration: joint velocity actions in the first 7 dimensions, gripper actions in the 8th dimension, and a control frequency of 15 Hz.


================================================
FILE: docs/remote_inference.md
================================================

# Running openpi models remotely

We provide utilities for running openpi models remotely. This is useful for running inference on more powerful off-robot GPUs, and it also keeps the robot and policy environments separate (e.g., to avoid dependency conflicts with robot software).

## Starting a remote policy server

To start a remote policy server, you can simply run the following command:

```bash
uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO]
```

The `env` argument specifies which $\pi_0$ checkpoint should be loaded. Under the hood, this is equivalent to a command like the following (shown here for the DROID environment), which you can also use directly to serve other checkpoints, e.g., ones you trained yourself:

```bash
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
```

This starts a policy server for the policy specified by the `--policy.config` and `--policy.dir` arguments. The policy is served on the configured port (default: 8000).

## Querying the remote policy server from your robot code

We provide a client utility with minimal dependencies that you can easily embed into any robot codebase.

First, install the `openpi-client` package in your robot environment:

```bash
cd $OPENPI_ROOT/packages/openpi-client
pip install -e .
```

Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this:

```python
from openpi_client import image_tools
from openpi_client import websocket_client_policy

# Outside of episode loop, initialize the policy client.
# Point to the host and port of the policy server (localhost and 8000 are the defaults).
client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)

# `num_steps`, `img`, `wrist_img`, `state`, and `task_instruction` are placeholders for values from your robot code.
for step in range(num_steps):
    # Inside the episode loop, construct the observation.
    # Resize images on the client side to minimize bandwidth and latency. Always send images in uint8 format.
    # We provide utilities for resizing and uint8 conversion so that your preprocessing matches training.
    # The typical resize_size for pre-trained pi0 models is 224.
    # The proprioceptive `state` can be passed unnormalized; normalization is handled on the server side.
    observation = {
        "observation/image": image_tools.convert_to_uint8(
            image_tools.resize_with_pad(img, 224, 224)
        ),
        "observation/wrist_image": image_tools.convert_to_uint8(
            image_tools.resize_with_pad(wrist_img, 224, 224)
        ),
        "observation/state": state,
        "prompt": task_instruction,
    }

    # Call the policy server with the current observation.
    # This returns an action chunk of shape (action_horizon, action_dim).
    # Note that you typically only need to call the policy every N steps and execute steps
    # from the predicted action chunk open-loop in the remaining steps.
    action_chunk = client.infer(observation)["actions"]

    # Execute the actions in the environment.
    ...

```

Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](../examples/simple_client/main.py).
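
The loop above queries the server on every step. As its comments note, you can instead query every N steps and replay the returned chunk open-loop. The `openpi_client` package provides an `ActionChunkBroker` wrapper for this (used in `examples/aloha_real/main.py`); the class below is a simplified, hypothetical re-implementation that only shows the idea, not that wrapper's actual code. `fake_infer` is a stand-in for `client.infer` so the sketch runs without a server.

```python
from typing import Any, Callable

import numpy as np


class ChunkedExecutor:
    """Calls `infer` once every `replan_steps` steps and replays the chunk open-loop in between."""

    def __init__(self, infer: Callable[[dict], dict], replan_steps: int) -> None:
        self._infer = infer
        self._replan_steps = replan_steps
        self._chunk = None
        self._index = 0

    def act(self, observation: dict[str, Any]) -> np.ndarray:
        # Request a new chunk on the first step and whenever `replan_steps` actions have been used.
        if self._chunk is None or self._index >= self._replan_steps:
            self._chunk = np.asarray(self._infer(observation)["actions"])
            if self._chunk.shape[0] < self._replan_steps:
                raise ValueError("replan_steps exceeds the returned action horizon")
            self._index = 0
        action = self._chunk[self._index]
        self._index += 1
        return action


calls = []


def fake_infer(observation):
    """Hypothetical stand-in for client.infer: returns a (10, 14) action chunk."""
    calls.append(observation["step"])
    rows = np.arange(10, dtype=np.float32)[:, None] + observation["step"]
    return {"actions": np.tile(rows, (1, 14))}


executor = ChunkedExecutor(fake_infer, replan_steps=5)
actions = [executor.act({"step": t}) for t in range(12)]
print(calls)  # [0, 5, 10]
```

With a real server you would pass `client.infer` instead of `fake_infer`. Choosing `replan_steps` smaller than the action horizon trades extra server calls for more frequent feedback from new observations.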


================================================
FILE: examples/aloha_real/Dockerfile
================================================
# Dockerfile for the Aloha real environment.

# Build the container:
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash

FROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc
SHELL ["/bin/bash", "-c"]

ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    cmake \
    curl \
    libffi-dev \
    python3-rosdep \
    python3-rosinstall \
    python3-rosinstall-generator \
    whiptail \
    git \
    wget \
    openssh-client \
    ros-noetic-cv-bridge \
    ros-noetic-usb-cam \
    ros-noetic-realsense2-camera \
    keyboard-configuration

WORKDIR /root
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
RUN chmod +x xsarm_amd64_install.sh
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n

COPY ./third_party/aloha /root/interbotix_ws/src/aloha
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make

# Install Python 3.10 because this ROS image ships with Python 3.8
RUN mkdir /python && \
    cd /python && \
    wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
    tar -zxvf Python-3.10.14.tgz && \
    cd Python-3.10.14 && \
    ls -lhR && \
    ./configure --enable-optimizations && \
    make install && \
    echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
    echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
    cd ~ && rm -rf /python && \
    rm -rf /var/lib/apt/lists/*

COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
ENV UV_HTTP_TIMEOUT=120
ENV UV_LINK_MODE=copy
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml

ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
WORKDIR /app

# Create an entrypoint script to run the setup commands, followed by the command passed in.
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
#!/bin/bash
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
EOF
RUN chmod +x /usr/local/bin/entrypoint.sh

ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["python3", "/app/examples/aloha_real/main.py"]


================================================
FILE: examples/aloha_real/README.md
================================================
# Run Aloha (Real Robot)

This example demonstrates how to run a policy on a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.

## Prerequisites

This repo uses a fork of the ALOHA repo with very minor modifications to use RealSense cameras.

1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use your cameras' serial numbers.

## With Docker

```bash
export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
docker compose -f examples/aloha_real/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_real/.venv
source examples/aloha_real/.venv/bin/activate
uv pip sync examples/aloha_real/requirements.txt
uv pip install -e packages/openpi-client

# Run the robot
python -m examples.aloha_real.main
```

Terminal window 2:

```bash
roslaunch aloha ros_nodes.launch
```

Terminal window 3:

```bash
uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
```

## **ALOHA Checkpoint Guide**


The `pi0_base` model can be used zero-shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA.

While we’ve found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we’ve seen them work well on, and well-represented initial state distributions. Running these policies zero-shot is still highly experimental, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is to fine-tune it with data from the target robot.


---

### **Toast Task**

This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.

- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_base`
- **Prompt**: "take the toast out of the toaster"
- **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
- **Object Distribution**:
  - Works on both real toast and rubber fake toast
  - Compatible with standard 2-slice toasters
  - Works with plates of varying colors

### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />

- The toaster should be positioned in the top-left quadrant of the workspace.
- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
- The plate should be placed roughly in the lower-center of the workspace.
- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).


### **Towel Task**

This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.

- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_towel`
- **Prompt**: "fold the towel"
- **Object Distribution**:
  - Works on towels of varying solid colors
  - Performance is worse on heavily textured or striped towels

### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />

- The towel should be flattened and roughly centered on the table.
- Choose a towel that does not blend in with the table surface.


### **Tupperware Task**

This task involves opening a tupperware filled with food and pouring the contents onto a plate.

- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_tupperware`
- **Prompt**: "open the tupperware and put the food on the plate"
- **Objects needed**: Tupperware, food (or food-like items), and a plate.
- **Object Distribution**:
  - Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
  - Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see the image below).
  - The policy has seen plates of varying solid colors.

### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />

- Best performance is observed when both the tupperware and plate are roughly centered in the workspace.
- Positioning:
  - Tupperware should be on the left.
  - Plate should be on the right or bottom.
  - The tupperware flap should point toward the plate.

## Training on your own Aloha dataset

1. Convert the dataset to the LeRobot dataset v2.0 format.

    We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example, we converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).


2. Define a training config that uses the custom dataset.

    We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. Refer to the root [README](../../README.md) for how to run training with the new config.

IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the `trossen` `asset_id` and the path to the pretrained checkpoint’s asset directory in the `AssetsConfig`.


================================================
FILE: examples/aloha_real/compose.yml
================================================
# Run with:
# docker compose -f examples/aloha_real/compose.yml up --build
services:
  runtime:
    image: aloha_real
    depends_on:
      - aloha_ros_nodes
      - ros_master
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/aloha_real/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - $PWD:/app
      - ../../data:/data

  aloha_ros_nodes:
    image: aloha_real
    depends_on:
      - ros_master
    build:
      context: ../..
      dockerfile: examples/aloha_real/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - /dev:/dev
    command: roslaunch --wait aloha ros_nodes.launch

  ros_master:
    image: ros:noetic-robot
    network_mode: host
    privileged: true
    command:
      - roscore

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]


================================================
FILE: examples/aloha_real/constants.py
================================================
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa

### Task parameters

### ALOHA fixed constants
DT = 0.001
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]

# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
MASTER_GRIPPER_POSITION_OPEN = 0.02417
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844

# Gripper joint limits (qpos[6])
MASTER_GRIPPER_JOINT_OPEN = 0.3083
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213

############################ Helper functions ############################

MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
    MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
    PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
    lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
    lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))

MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
    MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
    PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
    lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
    lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))

MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)

MASTER_POS2JOINT = (
    lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
    + MASTER_GRIPPER_JOINT_CLOSE
)
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
    (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
)
PUPPET_POS2JOINT = (
    lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
    + PUPPET_GRIPPER_JOINT_CLOSE
)
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
    (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
)

MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2


================================================
FILE: examples/aloha_real/convert_aloha_data_to_lerobot.py
================================================
"""
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.

Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
"""

import dataclasses
from pathlib import Path
import shutil
from typing import Literal

import h5py
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
import numpy as np
import torch
import tqdm
import tyro


@dataclasses.dataclass(frozen=True)
class DatasetConfig:
    use_videos: bool = True
    tolerance_s: float = 0.0001
    image_writer_processes: int = 10
    image_writer_threads: int = 5
    video_backend: str | None = None


DEFAULT_DATASET_CONFIG = DatasetConfig()


def create_empty_dataset(
    repo_id: str,
    robot_type: str,
    mode: Literal["video", "image"] = "video",
    *,
    has_velocity: bool = False,
    has_effort: bool = False,
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
) -> LeRobotDataset:
    motors = [
        "right_waist",
        "right_shoulder",
        "right_elbow",
        "right_forearm_roll",
        "right_wrist_angle",
        "right_wrist_rotate",
        "right_gripper",
        "left_waist",
        "left_shoulder",
        "left_elbow",
        "left_forearm_roll",
        "left_wrist_angle",
        "left_wrist_rotate",
        "left_gripper",
    ]
    cameras = [
        "cam_high",
        "cam_low",
        "cam_left_wrist",
        "cam_right_wrist",
    ]

    features = {
        "observation.state": {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        },
        "action": {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        },
    }

    if has_velocity:
        features["observation.velocity"] = {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        }

    if has_effort:
        features["observation.effort"] = {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        }

    for cam in cameras:
        features[f"observation.images.{cam}"] = {
            "dtype": mode,
            "shape": (3, 480, 640),
            "names": [
                "channels",
                "height",
                "width",
            ],
        }

    if Path(LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    return LeRobotDataset.create(
        repo_id=repo_id,
        fps=50,
        robot_type=robot_type,
        features=features,
        use_videos=dataset_config.use_videos,
        tolerance_s=dataset_config.tolerance_s,
        image_writer_processes=dataset_config.image_writer_processes,
        image_writer_threads=dataset_config.image_writer_threads,
        video_backend=dataset_config.video_backend,
    )


def get_cameras(hdf5_files: list[Path]) -> list[str]:
    with h5py.File(hdf5_files[0], "r") as ep:
        # ignore depth channel, not currently handled
        return [key for key in ep["/observations/images"].keys() if "depth" not in key]  # noqa: SIM118


def has_velocity(hdf5_files: list[Path]) -> bool:
    with h5py.File(hdf5_files[0], "r") as ep:
        return "/observations/qvel" in ep


def has_effort(hdf5_files: list[Path]) -> bool:
    with h5py.File(hdf5_files[0], "r") as ep:
        return "/observations/effort" in ep


def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
    imgs_per_cam = {}
    for camera in cameras:
        uncompressed = ep[f"/observations/images/{camera}"].ndim == 4

        if uncompressed:
            # load all images in RAM
            imgs_array = ep[f"/observations/images/{camera}"][:]
        else:
            import cv2

            # load one compressed image after the other in RAM and uncompress
            imgs_array = []
            for data in ep[f"/observations/images/{camera}"]:
                imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))
            imgs_array = np.array(imgs_array)

        imgs_per_cam[camera] = imgs_array
    return imgs_per_cam


def load_raw_episode_data(
    ep_path: Path,
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    with h5py.File(ep_path, "r") as ep:
        state = torch.from_numpy(ep["/observations/qpos"][:])
        action = torch.from_numpy(ep["/action"][:])

        velocity = None
        if "/observations/qvel" in ep:
            velocity = torch.from_numpy(ep["/observations/qvel"][:])

        effort = None
        if "/observations/effort" in ep:
            effort = torch.from_numpy(ep["/observations/effort"][:])

        imgs_per_cam = load_raw_images_per_camera(
            ep,
            [
                "cam_high",
                "cam_low",
                "cam_left_wrist",
                "cam_right_wrist",
            ],
        )

    return imgs_per_cam, state, action, velocity, effort


def populate_dataset(
    dataset: LeRobotDataset,
    hdf5_files: list[Path],
    task: str,
    episodes: list[int] | None = None,
) -> LeRobotDataset:
    if episodes is None:
        episodes = range(len(hdf5_files))

    for ep_idx in tqdm.tqdm(episodes):
        ep_path = hdf5_files[ep_idx]

        imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
        num_frames = state.shape[0]

        for i in range(num_frames):
            frame = {
                "observation.state": state[i],
                "action": action[i],
            }

            for camera, img_array in imgs_per_cam.items():
                frame[f"observation.images.{camera}"] = img_array[i]

            if velocity is not None:
                frame["observation.velocity"] = velocity[i]
            if effort is not None:
                frame["observation.effort"] = effort[i]

            dataset.add_frame(frame)

        dataset.save_episode(task=task)

    return dataset


def port_aloha(
    raw_dir: Path,
    repo_id: str,
    raw_repo_id: str | None = None,
    task: str = "DEBUG",
    *,
    episodes: list[int] | None = None,
    push_to_hub: bool = True,
    is_mobile: bool = False,
    mode: Literal["video", "image"] = "image",
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
    if (LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    if not raw_dir.exists():
        if raw_repo_id is None:
            raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
        download_raw(raw_dir, repo_id=raw_repo_id)

    hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))

    dataset = create_empty_dataset(
        repo_id,
        robot_type="mobile_aloha" if is_mobile else "aloha",
        mode=mode,
        has_effort=has_effort(hdf5_files),
        has_velocity=has_velocity(hdf5_files),
        dataset_config=dataset_config,
    )
    dataset = populate_dataset(
        dataset,
        hdf5_files,
        task=task,
        episodes=episodes,
    )
    dataset.consolidate()

    if push_to_hub:
        dataset.push_to_hub()


if __name__ == "__main__":
    tyro.cli(port_aloha)


================================================
FILE: examples/aloha_real/env.py
================================================
from typing import List, Optional  # noqa: UP035

import einops
from openpi_client import image_tools
from openpi_client.runtime import environment as _environment
from typing_extensions import override

from examples.aloha_real import real_env as _real_env


class AlohaRealEnvironment(_environment.Environment):
    """An environment for an Aloha robot on real hardware."""

    def __init__(
        self,
        reset_position: Optional[List[float]] = None,  # noqa: UP006,UP007
        render_height: int = 224,
        render_width: int = 224,
    ) -> None:
        self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
        self._render_height = render_height
        self._render_width = render_width

        self._ts = None

    @override
    def reset(self) -> None:
        self._ts = self._env.reset()

    @override
    def is_episode_complete(self) -> bool:
        return False

    @override
    def get_observation(self) -> dict:
        if self._ts is None:
            raise RuntimeError("Timestep is not set. Call reset() first.")

        obs = self._ts.observation
        for k in list(obs["images"].keys()):
            if "_depth" in k:
                del obs["images"][k]

        for cam_name in obs["images"]:
            img = image_tools.convert_to_uint8(
                image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
            )
            obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")

        return {
            "state": obs["qpos"],
            "images": obs["images"],
        }

    @override
    def apply_action(self, action: dict) -> None:
        self._ts = self._env.step(action["actions"])


================================================
FILE: examples/aloha_real/main.py
================================================
import dataclasses
import logging

from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import tyro

from examples.aloha_real import env as _env


@dataclasses.dataclass
class Args:
    host: str = "0.0.0.0"
    port: int = 8000

    action_horizon: int = 25

    num_episodes: int = 1
    max_episode_steps: int = 1000


def main(args: Args) -> None:
    ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
    )
    metadata = ws_client_policy.get_server_metadata()
    logging.info(f"Server metadata: {metadata}")
    runtime = _runtime.Runtime(
        environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
        agent=_policy_agent.PolicyAgent(
            policy=action_chunk_broker.ActionChunkBroker(
                policy=ws_client_policy,
                action_horizon=args.action_horizon,
            )
        ),
        subscribers=[],
        max_hz=50,
        num_episodes=args.num_episodes,
        max_episode_steps=args.max_episode_steps,
    )

    runtime.run()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    tyro.cli(main)


================================================
FILE: examples/aloha_real/real_env.py
================================================
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
import collections
import time
from typing import Optional, List
import dm_env
from interbotix_xs_modules.arm import InterbotixManipulatorXS
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np

from examples.aloha_real import constants
from examples.aloha_real import robot_utils

# This is the reset position that is used by the standard Aloha runtime.
DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]


class RealEnv:
    """
    Environment for real-robot bimanual manipulation.

    Action space:      [left_arm_qpos (6),             # absolute joint position
                        left_gripper_positions (1),    # normalized gripper position (0: close, 1: open)
                        right_arm_qpos (6),            # absolute joint position
                        right_gripper_positions (1)]   # normalized gripper position (0: close, 1: open)

    Observation space: {"qpos": Concat[ left_arm_qpos (6),           # absolute joint position
                                        left_gripper_position (1),   # normalized gripper position (0: close, 1: open)
                                        right_arm_qpos (6),          # absolute joint position
                                        right_gripper_position (1)]  # normalized gripper position (0: close, 1: open)
                        "qvel": Concat[ left_arm_qvel (6),           # absolute joint velocity (rad/s)
                                        left_gripper_velocity (1),   # normalized gripper velocity (pos: opening, neg: closing)
                                        right_arm_qvel (6),          # absolute joint velocity (rad/s)
                                        right_gripper_velocity (1)]  # normalized gripper velocity (pos: opening, neg: closing)
                        "effort": Concat[ left_effort (7),           # raw effort for 6 arm joints + gripper
                                          right_effort (7)]          # raw effort for 6 arm joints + gripper
                        "images": {"cam_high": (480x640x3),          # h, w, c, dtype='uint8'
                                   "cam_low": (480x640x3),           # h, w, c, dtype='uint8'
                                   "cam_left_wrist": (480x640x3),    # h, w, c, dtype='uint8'
                                   "cam_right_wrist": (480x640x3)}   # h, w, c, dtype='uint8'
                        # "images" also contains "<cam>_depth" keys, which are currently always None.
    """

    def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
        # reset_position = START_ARM_POSE[:6]
        self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION

        self.puppet_bot_left = InterbotixManipulatorXS(
            robot_model="vx300s",
            group_name="arm",
            gripper_name="gripper",
            robot_name="puppet_left",
            init_node=init_node,
        )
        self.puppet_bot_right = InterbotixManipulatorXS(
            robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
        )
        if setup_robots:
            self.setup_robots()

        self.recorder_left = robot_utils.Recorder("left", init_node=False)
        self.recorder_right = robot_utils.Recorder("right", init_node=False)
        self.image_recorder = robot_utils.ImageRecorder(init_node=False)
        self.gripper_command = JointSingleCommand(name="gripper")

    def setup_robots(self):
        robot_utils.setup_puppet_bot(self.puppet_bot_left)
        robot_utils.setup_puppet_bot(self.puppet_bot_right)

    def get_qpos(self):
        left_qpos_raw = self.recorder_left.qpos
        right_qpos_raw = self.recorder_right.qpos
        left_arm_qpos = left_qpos_raw[:6]
        right_arm_qpos = right_qpos_raw[:6]
        left_gripper_qpos = [
            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
        ]  # this is position not joint
        right_gripper_qpos = [
            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
        ]  # this is position not joint
        return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])

    def get_qvel(self):
        left_qvel_raw = self.recorder_left.qvel
        right_qvel_raw = self.recorder_right.qvel
        left_arm_qvel = left_qvel_raw[:6]
        right_arm_qvel = right_qvel_raw[:6]
        left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
        right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
        return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])

    def get_effort(self):
        left_effort_raw = self.recorder_left.effort
        right_effort_raw = self.recorder_right.effort
        left_robot_effort = left_effort_raw[:7]
        right_robot_effort = right_effort_raw[:7]
        return np.concatenate([left_robot_effort, right_robot_effort])

    def get_images(self):
        return self.image_recorder.get_images()

    def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
        left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
        self.gripper_command.cmd = left_gripper_desired_joint
        self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)

        right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
            right_gripper_desired_pos_normalized
        )
        self.gripper_command.cmd = right_gripper_desired_joint
        self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)

    def _reset_joints(self):
        robot_utils.move_arms(
            [self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
        )

    def _reset_gripper(self):
        """Reset the grippers by first closing them and then opening them.

        NOTE: This diverges from the original Aloha code, which first opens and then closes the grippers. Pi's internal
        Aloha data was collected with the grippers starting in the open position, and leaving the grippers fully closed
        was also found to increase the frequency of motor faults.
        """
        robot_utils.move_grippers(
            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
        )
        robot_utils.move_grippers(
            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
        )

    def get_observation(self):
        obs = collections.OrderedDict()
        obs["qpos"] = self.get_qpos()
        obs["qvel"] = self.get_qvel()
        obs["effort"] = self.get_effort()
        obs["images"] = self.get_images()
        return obs

    def get_reward(self):
        return 0

    def reset(self, *, fake=False):
        if not fake:
            # Reboot puppet robot gripper motors
            self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
            self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
            self._reset_joints()
            self._reset_gripper()
        return dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
        )

    def step(self, action):
        state_len = len(action) // 2  # 7 values per arm: 6 joints + 1 gripper
        left_action = action[:state_len]
        right_action = action[state_len:]
        self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
        self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
        self.set_gripper_pose(left_action[-1], right_action[-1])
        time.sleep(constants.DT)
        return dm_env.TimeStep(
            step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
        )


def get_action(master_bot_left, master_bot_right):
    action = np.zeros(14)  # 6 joint + 1 gripper, for two arms
    # Arm actions
    action[:6] = master_bot_left.dxl.joint_states.position[:6]
    action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
    # Gripper actions
    action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
    action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])

    return action


def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
    return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)
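`get_action` packs a 14-dim vector as `[left arm (6), left gripper (1), right arm (6), right gripper (1)]`. `RealEnv.step` undoes this by splitting the vector in half and then reading each half's first six joints and last (gripper) value. The helpers below are hypothetical and only make that layout explicit. According to the docstring, gripper entries are normalized to 0 (closed) through 1 (open).

```python
import numpy as np

ARM_DOF = 6  # joints per arm, as in RealEnv
PER_ARM = ARM_DOF + 1  # plus one normalized gripper value


def pack_bimanual_action(left_arm, left_gripper, right_arm, right_gripper) -> np.ndarray:
    """Build a 14-dim action in the same order that get_action() uses."""
    action = np.zeros(2 * PER_ARM)
    action[:ARM_DOF] = left_arm
    action[ARM_DOF] = left_gripper
    action[PER_ARM : PER_ARM + ARM_DOF] = right_arm
    action[PER_ARM + ARM_DOF] = right_gripper
    return action


def split_bimanual_action(action: np.ndarray):
    """Split an action the way RealEnv.step() does: halves first, then arm joints and gripper."""
    half = len(action) // 2
    left, right = action[:half], action[half:]
    return left[:ARM_DOF], left[-1], right[:ARM_DOF], right[-1]
```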


================================================
FILE: examples/aloha_real/requirements.in
================================================
Pillow
dm_control
einops
h5py
matplotlib
modern_robotics
msgpack
numpy>=1.22.4,<2.0.0
opencv-python
packaging
pexpect
pyquaternion
pyrealsense2
pyyaml
requests
rospkg
tyro
websockets


================================================
FILE: examples/aloha_real/requirements.txt
================================================
# This file was autogenerated by uv via the following command:
#    uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
absl-py==2.1.0
    # via
    #   dm-control
    #   dm-env
    #   labmaze
    #   mujoco
catkin-pkg==1.0.0
    # via rospkg
certifi==2024.8.30
    # via requests
charset-normalizer==3.4.0
    # via requests
contourpy==1.1.1
    # via matplotlib
cycler==0.12.1
    # via matplotlib
distro==1.9.0
    # via rospkg
dm-control==1.0.23
    # via -r examples/aloha_real/requirements.in
dm-env==1.6
    # via dm-control
dm-tree==0.1.8
    # via
    #   dm-control
    #   dm-env
docstring-parser==0.16
    # via tyro
docutils==0.20.1
    # via catkin-pkg
einops==0.8.0
    # via -r examples/aloha_real/requirements.in
etils==1.3.0
    # via mujoco
fonttools==4.55.2
    # via matplotlib
glfw==2.8.0
    # via
    #   dm-control
    #   mujoco
h5py==3.11.0
    # via -r examples/aloha_real/requirements.in
idna==3.10
    # via requests
importlib-resources==6.4.5
    # via etils
kiwisolver==1.4.7
    # via matplotlib
labmaze==1.0.6
    # via dm-control
lxml==5.3.0
    # via dm-control
markdown-it-py==3.0.0
    # via rich
matplotlib==3.7.5
    # via -r examples/aloha_real/requirements.in
mdurl==0.1.2
    # via markdown-it-py
modern-robotics==1.1.1
    # via -r examples/aloha_real/requirements.in
msgpack==1.1.0
    # via -r examples/aloha_real/requirements.in
mujoco==3.2.3
    # via dm-control
numpy==1.24.4
    # via
    #   -r examples/aloha_real/requirements.in
    #   contourpy
    #   dm-control
    #   dm-env
    #   h5py
    #   labmaze
    #   matplotlib
    #   modern-robotics
    #   mujoco
    #   opencv-python
    #   pyquaternion
    #   scipy
opencv-python==4.10.0.84
    # via -r examples/aloha_real/requirements.in
packaging==24.2
    # via
    #   -r examples/aloha_real/requirements.in
    #   matplotlib
pexpect==4.9.0
    # via -r examples/aloha_real/requirements.in
pillow==10.4.0
    # via
    #   -r examples/aloha_real/requirements.in
    #   matplotlib
protobuf==5.29.1
    # via dm-control
ptyprocess==0.7.0
    # via pexpect
pygments==2.18.0
    # via rich
pyopengl==3.1.7
    # via
    #   dm-control
    #   mujoco
pyparsing==3.1.4
    # via
    #   catkin-pkg
    #   dm-control
    #   matplotlib
pyquaternion==0.9.9
    # via -r examples/aloha_real/requirements.in
pyrealsense2==2.55.1.6486
    # via -r examples/aloha_real/requirements.in
python-dateutil==2.9.0.post0
    # via
    #   catkin-pkg
    #   matplotlib
pyyaml==6.0.2
    # via
    #   -r examples/aloha_real/requirements.in
    #   rospkg
requests==2.32.3
    # via
    #   -r examples/aloha_real/requirements.in
    #   dm-control
rich==13.9.4
    # via tyro
rospkg==1.5.1
    # via -r examples/aloha_real/requirements.in
scipy==1.10.1
    # via dm-control
setuptools==75.3.0
    # via
    #   catkin-pkg
    #   dm-control
    #   labmaze
shtab==1.7.1
    # via tyro
six==1.17.0
    # via python-dateutil
tqdm==4.67.1
    # via dm-control
typeguard==4.4.0
    # via tyro
typing-extensions==4.12.2
    # via
    #   etils
    #   rich
    #   typeguard
    #   tyro
tyro==0.9.2
    # via -r examples/aloha_real/requirements.in
urllib3==2.2.3
    # via requests
websockets==14.1
    # via -r examples/aloha_real/requirements.in
zipp==3.20.2
    # via etils


================================================
FILE: examples/aloha_real/robot_utils.py
================================================
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
from collections import deque
import datetime
import json
import time

from aloha.msg import RGBGrayscaleImage
from cv_bridge import CvBridge
from interbotix_xs_msgs.msg import JointGroupCommand
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np
import rospy
from sensor_msgs.msg import JointState

from examples.aloha_real import constants


class ImageRecorder:
    def __init__(self, init_node=True, is_debug=False):
        self.is_debug = is_debug
        self.bridge = CvBridge()
        self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]

        if init_node:
            rospy.init_node("image_recorder", anonymous=True)
        for cam_name in self.camera_names:
            setattr(self, f"{cam_name}_rgb_image", None)
            setattr(self, f"{cam_name}_depth_image", None)
            setattr(self, f"{cam_name}_timestamp", 0.0)
            if cam_name == "cam_high":
                callback_func = self.image_cb_cam_high
            elif cam_name == "cam_low":
                callback_func = self.image_cb_cam_low
            elif cam_name == "cam_left_wrist":
                callback_func = self.image_cb_cam_left_wrist
            elif cam_name == "cam_right_wrist":
                callback_func = self.image_cb_cam_right_wrist
            else:
                raise NotImplementedError
            rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
            if self.is_debug:
                setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))

        self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
        time.sleep(0.5)

    def image_cb(self, cam_name, data):
        setattr(
            self,
            f"{cam_name}_rgb_image",
            self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
        )
        # setattr(
        #     self,
        #     f"{cam_name}_depth_image",
        #     self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
        # )
        setattr(
            self,
            f"{cam_name}_timestamp",
            data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
        )
        # setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
        # setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
        # cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
        if self.is_debug:
            getattr(self, f"{cam_name}_timestamps").append(
                data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
            )

    def image_cb_cam_high(self, data):
        cam_name = "cam_high"
        return self.image_cb(cam_name, data)

    def image_cb_cam_low(self, data):
        cam_name = "cam_low"
        return self.image_cb(cam_name, data)

    def image_cb_cam_left_wrist(self, data):
        cam_name = "cam_left_wrist"
        return self.image_cb(cam_name, data)

    def image_cb_cam_right_wrist(self, data):
        cam_name = "cam_right_wrist"
        return self.image_cb(cam_name, data)

    def get_images(self):
        image_dict = {}
        for cam_name in self.camera_names:
            while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
                time.sleep(0.00001)
            rgb_image = getattr(self, f"{cam_name}_rgb_image")
            depth_image = getattr(self, f"{cam_name}_depth_image")
            self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
            image_dict[cam_name] = rgb_image
            image_dict[f"{cam_name}_depth"] = depth_image
        return image_dict

    def print_diagnostics(self):
        def dt_helper(l):
            l = np.array(l)
            diff = l[1:] - l[:-1]
            return np.mean(diff)

        for cam_name in self.camera_names:
            image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
            print(f"{cam_name} {image_freq=:.2f}")
        print()


class Recorder:
    def __init__(self, side, init_node=True, is_debug=False):
        self.secs = None
        self.nsecs = None
        self.qpos = None
        self.qvel = None
        self.effort = None
        self.arm_command = None
        self.gripper_command = None
        self.is_debug = is_debug

        if init_node:
            rospy.init_node("recorder", anonymous=True)
        rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
        rospy.Subscriber(
            f"/puppet_{side}/commands/joint_group",
            JointGroupCommand,
            self.puppet_arm_commands_cb,
        )
        rospy.Subscriber(
            f"/puppet_{side}/commands/joint_single",
            JointSingleCommand,
            self.puppet_gripper_commands_cb,
        )
        if self.is_debug:
            self.joint_timestamps = deque(maxlen=50)
            self.arm_command_timestamps = deque(maxlen=50)
            self.gripper_command_timestamps = deque(maxlen=50)
        time.sleep(0.1)

    def puppet_state_cb(self, data):
        self.qpos = data.position
        self.qvel = data.velocity
        self.effort = data.effort
        self.data = data
        if self.is_debug:
            self.joint_timestamps.append(time.time())

    def puppet_arm_commands_cb(self, data):
        self.arm_command = data.cmd
        if self.is_debug:
            self.arm_command_timestamps.append(time.time())

    def puppet_gripper_commands_cb(self, data):
        self.gripper_command = data.cmd
        if self.is_debug:
            self.gripper_command_timestamps.append(time.time())

    def print_diagnostics(self):
        def dt_helper(l):
            l = np.array(l)
            diff = l[1:] - l[:-1]
            return np.mean(diff)

        joint_freq = 1 / dt_helper(self.joint_timestamps)
        arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
        gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)

        print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")


def get_arm_joint_positions(bot):
    return bot.arm.core.joint_states.position[:6]


def get_arm_gripper_positions(bot):
    return bot.gripper.core.joint_states.position[6]


def move_arms(bot_list, target_pose_list, move_time=1):
    num_steps = int(move_time / constants.DT)
    curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
    traj_list = [
        np.linspace(curr_pose, target_pose, num_steps)
        for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
    ]
    for t in range(num_steps):
        for bot_id, bot in enumerate(bot_list):
            bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
        time.sleep(constants.DT)


def move_grippers(bot_list, target_pose_list, move_time):
    print(f"Moving grippers to {target_pose_list=}")
    gripper_command = JointSingleCommand(name="gripper")
    num_steps = int(move_time / constants.DT)
    curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
    traj_list = [
        np.linspace(curr_pose, target_pose, num_steps)
        for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
    ]

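    # Log commanded vs. observed gripper positions for debugging; the /data directory must exist and be writable.
    with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f: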
        for t in range(num_steps):
            d = {}
            for bot_id, bot in enumerate(bot_list):
                gripper_command.cmd = traj_list[bot_id][t]
                bot.gripper.core.pub_single.publish(gripper_command)
                d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
            f.write(json.dumps(d) + "\n")
            time.sleep(constants.DT)


def setup_puppet_bot(bot):
    bot.dxl.robot_reboot_motors("single", "gripper", True)
    bot.dxl.robot_set_operating_modes("group", "arm", "position")
    bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
    torque_on(bot)


def setup_master_bot(bot):
    bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
    bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
    torque_off(bot)


def set_standard_pid_gains(bot):
    bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
    bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)


def set_low_pid_gains(bot):
    bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
    bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)


def torque_off(bot):
    bot.dxl.robot_torque_enable("group", "arm", False)
    bot.dxl.robot_torque_enable("single", "gripper", False)


def torque_on(bot):
    bot.dxl.robot_torque_enable("group", "arm", True)
    bot.dxl.robot_torque_enable("single", "gripper", True)


# for DAgger
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
    print("\nSyncing!")

    # activate master arms
    torque_on(master_bot_left)
    torque_on(master_bot_right)

    # get puppet arm positions
    puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
    puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)

    # get puppet gripper positions
    puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
    puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)

    # move master arms to puppet positions
    move_arms(
        [master_bot_left, master_bot_right],
        [puppet_left_qpos, puppet_right_qpos],
        move_time=1,
    )

    # move master grippers to puppet positions
    move_grippers(
        [master_bot_left, master_bot_right],
        [puppet_left_gripper, puppet_right_gripper],
        move_time=1,
    )
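`move_arms` and `move_grippers` both interpolate linearly in joint space. They build `int(move_time / constants.DT)` waypoints with `np.linspace` and send one waypoint per control period. The function below is a standalone sketch of that step. It assumes `DT = 0.02` s as a placeholder, because the real value lives in `constants.py`, which is not shown here. The name `linear_joint_trajectory` is hypothetical.

```python
import numpy as np

DT = 0.02  # Assumed control period in seconds; the real value is constants.DT.


def linear_joint_trajectory(curr_pose, target_pose, move_time: float, dt: float = DT) -> np.ndarray:
    """Return [num_steps, num_joints] waypoints from curr_pose to target_pose.

    Mirrors the np.linspace interpolation in move_arms(): each row would be sent
    with set_joint_positions(..., blocking=False) followed by time.sleep(dt).
    """
    num_steps = int(move_time / dt)
    # linspace accepts array-valued endpoints and stacks waypoints along axis 0.
    return np.linspace(curr_pose, target_pose, num_steps)
```

Because `int()` truncates, a `move_time` shorter than one period yields zero waypoints, and the target is then never commanded.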


================================================
FILE: examples/aloha_real/video_display.py
================================================
from matplotlib.image import AxesImage
import matplotlib.pyplot as plt
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override


class VideoDisplay(_subscriber.Subscriber):
    """Displays the first camera image of each step in a live matplotlib window."""

    def __init__(self) -> None:
        self._ax: plt.Axes | None = None
        self._plt_img: AxesImage | None = None

    @override
    def on_episode_start(self) -> None:
        plt.ion()
        self._ax = plt.subplot()
        self._plt_img = None

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        assert self._ax is not None

        im = observation["image"][0]  # [C, H, W]
        im = np.transpose(im, (1, 2, 0))  # [H, W, C]

        if self._plt_img is None:
            self._plt_img = self._ax.imshow(im)
        else:
            self._plt_img.set_data(im)
        plt.pause(0.001)

    @override
    def on_episode_end(self) -> None:
        plt.ioff()
        plt.close()
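`VideoDisplay.on_step` assumes channel-first `[C, H, W]` frames and transposes them to `[H, W, C]` for `imshow`. `AlohaSimEnvironment` performs the opposite conversion when it builds observations. The helper names below are hypothetical. They show that the two axis orders are inverses of each other.

```python
import numpy as np


def hwc_to_chw(img: np.ndarray) -> np.ndarray:
    """[H, W, C] -> [C, H, W], the channel-first layout the environments emit."""
    return np.transpose(img, (2, 0, 1))


def chw_to_hwc(img: np.ndarray) -> np.ndarray:
    """[C, H, W] -> [H, W, C], the layout matplotlib's imshow expects."""
    return np.transpose(img, (1, 2, 0))
```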


================================================
FILE: examples/aloha_sim/Dockerfile
================================================
# Dockerfile for the Aloha simulation environment.

# Build the container:
# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash

FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

RUN apt-get update && \
    apt-get install -y \
    libosmesa6-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libglfw3-dev \
    libgles2-mesa-dev
ENV MUJOCO_GL=egl

WORKDIR /app

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml

# Install python dependencies.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src

CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]

================================================
FILE: examples/aloha_sim/README.md
================================================
# Run Aloha Sim

## With Docker

```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/aloha_sim/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_sim/.venv
source examples/aloha_sim/.venv/bin/activate
uv pip sync examples/aloha_sim/requirements.txt
uv pip install -e packages/openpi-client

# Run the simulation
MUJOCO_GL=egl python examples/aloha_sim/main.py
```

Note: If you are seeing EGL errors, you may need to install the following dependencies:

```bash
sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
```

Terminal window 2:

```bash
# Run the server
uv run scripts/serve_policy.py --env ALOHA_SIM
```


================================================
FILE: examples/aloha_sim/compose.yml
================================================
# Run with:
# docker compose -f examples/aloha_sim/compose.yml up --build
services:
  runtime:
    image: aloha_sim
    depends_on:
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/aloha_sim/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - $PWD:/app
      - ../../data:/data

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]


================================================
FILE: examples/aloha_sim/env.py
================================================
import gym_aloha  # noqa: F401
import gymnasium
import numpy as np
from openpi_client import image_tools
from openpi_client.runtime import environment as _environment
from typing_extensions import override


class AlohaSimEnvironment(_environment.Environment):
    """An environment for an Aloha robot in simulation."""

    def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
        np.random.seed(seed)
        self._rng = np.random.default_rng(seed)

        self._gym = gymnasium.make(task, obs_type=obs_type)

        self._last_obs = None
        self._done = True
        self._episode_reward = 0.0

    @override
    def reset(self) -> None:
        gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = False
        self._episode_reward = 0.0

    @override
    def is_episode_complete(self) -> bool:
        return self._done

    @override
    def get_observation(self) -> dict:
        if self._last_obs is None:
            raise RuntimeError("Observation is not set. Call reset() first.")

        return self._last_obs  # type: ignore

    @override
    def apply_action(self, action: dict) -> None:
        gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"])
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = terminated or truncated
        self._episode_reward = max(self._episode_reward, reward)

    def _convert_observation(self, gym_obs: dict) -> dict:
        img = gym_obs["pixels"]["top"]
        img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
        # Convert axis order from [H, W, C] --> [C, H, W]
        img = np.transpose(img, (2, 0, 1))

        return {
            "state": gym_obs["agent_pos"],
            "images": {"cam_high": img},
        }
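`_convert_observation` fits the top camera image into 224x224 with `image_tools.resize_with_pad` and then moves channels first. The NumPy-only sketch below shows the letterbox idea: scale to fit while keeping the aspect ratio, then pad. It assumes zero padding centered on both axes. `letterbox_nearest` is hypothetical and uses nearest-neighbor sampling, so it is not the `openpi_client` implementation, which may interpolate differently.

```python
import numpy as np


def letterbox_nearest(img: np.ndarray, height: int, width: int) -> np.ndarray:
    """Fit an [H, W, C] image inside height x width, keeping aspect ratio, then zero-pad."""
    cur_h, cur_w = img.shape[:2]
    scale = min(height / cur_h, width / cur_w)
    new_h = max(1, int(round(cur_h * scale)))
    new_w = max(1, int(round(cur_w * scale)))
    # Map each output row/column back to its nearest source row/column.
    rows = np.minimum((np.arange(new_h) / scale).astype(int), cur_h - 1)
    cols = np.minimum((np.arange(new_w) / scale).astype(int), cur_w - 1)
    resized = img[rows[:, None], cols[None, :]]
    # Center the resized image on a zero canvas of the requested size.
    out = np.zeros((height, width) + img.shape[2:], dtype=img.dtype)
    top = (height - new_h) // 2
    left = (width - new_w) // 2
    out[top : top + new_h, left : left + new_w] = resized
    return out
```

For a 480x640 frame, the scale is 0.35. The image becomes 168x224, and 28 rows of padding are added above and below it.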


================================================
FILE: examples/aloha_sim/main.py
================================================
import dataclasses
import logging
import pathlib

import env as _env
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import saver as _saver
import tyro


@dataclasses.dataclass
class Args:
    out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")

    task: str = "gym_aloha/AlohaTransferCube-v0"
    seed: int = 0

    action_horizon: int = 10

    host: str = "0.0.0.0"
    port: int = 8000

    display: bool = False


def main(args: Args) -> None:
    runtime = _runtime.Runtime(
        environment=_env.AlohaSimEnvironment(
            task=args.task,
            seed=args.seed,
        ),
        agent=_policy_agent.PolicyAgent(
            policy=action_chunk_broker.ActionChunkBroker(
                policy=_websocket_client_policy.WebsocketClientPolicy(
                    host=args.host,
                    port=args.port,
                ),
                action_horizon=args.action_horizon,
            )
        ),
        subscribers=[
            _saver.VideoSaver(args.out_dir),
        ],
        max_hz=50,
    )

    runtime.run()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    tyro.cli(main)
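Both `main.py` scripts pass `max_hz=50` to `_runtime.Runtime`, which presumably caps the control loop at 50 Hz (a 20 ms period). The sketch below shows one way to build such a rate limiter. It assumes a drift-free schedule based on `time.monotonic()` and resynchronizes instead of bursting whenever a step overruns. `run_at_rate` is hypothetical, and the actual `Runtime` may throttle differently.

```python
import time
from typing import Callable


def run_at_rate(step: Callable[[], None], max_hz: float, num_steps: int) -> None:
    """Call step() num_steps times, starting iterations no faster than max_hz."""
    period = 1.0 / max_hz
    next_tick = time.monotonic()
    for _ in range(num_steps):
        step()
        next_tick += period
        sleep_for = next_tick - time.monotonic()
        if sleep_for > 0:
            time.sleep(sleep_for)
        else:
            # Running behind schedule: resynchronize rather than firing steps back-to-back.
            next_tick = time.monotonic()
```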


================================================
FILE: examples/aloha_sim/requirements.in
================================================
gym-aloha
imageio
matplotlib
msgpack
numpy>=1.22.4,<2.0.0
typing-extensions
tyro
websockets

================================================
FILE: examples/aloha_sim/requirements.txt
================================================
# This file was autogenerated by uv via the following command:
#    uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
absl-py==2.1.0
    # via
    #   dm-control
    #   dm-env
    #   labmaze
    #   mujoco
certifi==2024.8.30
    # via requests
charset-normalizer==3.4.0
    # via requests
cloudpickle==3.1.0
    # via gymnasium
contourpy==1.3.1
    # via matplotlib
cycler==0.12.1
    # via matplotlib
dm-control==1.0.14
    # via gym-aloha
dm-env==1.6
    # via dm-control
dm-tree==0.1.8
    # via
    #   dm-control
    #   dm-env
docstring-parser==0.16
    # via tyro
farama-notifications==0.0.4
    # via gymnasium
fonttools==4.55.2
    # via matplotlib
glfw==2.8.0
    # via
    #   dm-control
    #   mujoco
gym-aloha==0.1.1
    # via -r examples/aloha_sim/requirements.in
gymnasium==1.0.0
    # via gym-aloha
idna==3.10
    # via requests
imageio==2.36.1
    # via
    #   -r examples/aloha_sim/requirements.in
    #   gym-aloha
imageio-ffmpeg==0.5.1
    # via imageio
kiwisolver==1.4.7
    # via matplotlib
labmaze==1.0.6
    # via dm-control
lxml==5.3.0
    # via dm-control
markdown-it-py==3.0.0
    # via rich
matplotlib==3.9.3
    # via -r examples/aloha_sim/requirements.in
mdurl==0.1.2
    # via markdown-it-py
msgpack==1.1.0
    # via -r examples/aloha_sim/requirements.in
mujoco==2.3.7
    # via
    #   dm-control
    #   gym-aloha
numpy==1.26.4
    # via
    #   -r examples/aloha_sim/requirements.in
    #   contourpy
    #   dm-control
    #   dm-env
    #   gymnasium
    #   imageio
    #   labmaze
    #   matplotlib
    #   mujoco
    #   scipy
packaging==24.2
    # via matplotlib
pillow==11.0.0
    # via
    #   imageio
    #   matplotlib
protobuf==5.29.1
    # via dm-control
psutil==6.1.0
    # via imageio
pygments==2.18.0
    # via rich
pyopengl==3.1.7
    # via
    #   dm-control
    #   mujoco
pyparsing==3.2.0
    # via
    #   dm-control
    #   matplotlib
python-dateutil==2.9.0.post0
    # via matplotlib
requests==2.32.3
    # via dm-control
rich==13.9.4
    # via tyro
scipy==1.14.1
    # via dm-control
setuptools==75.6.0
    # via
    #   dm-control
    #   imageio-ffmpeg
    #   labmaze
shtab==1.7.1
    # via tyro
six==1.17.0
    # via python-dateutil
tqdm==4.67.1
    # via dm-control
typeguard==4.4.1
    # via tyro
typing-extensions==4.12.2
    # via
    #   -r examples/aloha_sim/requirements.in
    #   gymnasium
    #   rich
    #   typeguard
    #   tyro
tyro==0.9.2
    # via -r examples/aloha_sim/requirements.in
urllib3==2.2.3
    # via requests
websockets==14.1
    # via -r examples/aloha_sim/requirements.in


================================================
FILE: examples/aloha_sim/saver.py
================================================
import logging
import pathlib

import imageio
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override


class VideoSaver(_subscriber.Subscriber):
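    """Saves the `cam_high` frames of each episode as an MP4 video in `out_dir`.

    Videos are named `out_<N>.mp4`, where N is one more than the largest existing index.
    Only every `subsample`-th frame is written, at `50 // subsample` fps.
    """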

    def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
        out_dir.mkdir(parents=True, exist_ok=True)
        self._out_dir = out_dir
        self._images: list[np.ndarray] = []
        self._subsample = subsample

    @override
    def on_episode_start(self) -> None:
        self._images = []

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        im = observation["images"]["cam_high"]  # [C, H, W]
        im = np.transpose(im, (1, 2, 0))  # [H, W, C]
        self._images.append(im)

    @override
    def on_episode_end(self) -> None:
        existing = list(self._out_dir.glob("out_[0-9]*.mp4"))
        next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1
        out_path = self._out_dir / f"out_{next_idx}.mp4"

        logging.info(f"Saving video to {out_path}")
        imageio.mimwrite(
            out_path,
            [np.asarray(x) for x in self._images[:: self._subsample]],
            fps=50 // max(1, self._subsample),
        )
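The file-naming and frame-rate logic in `VideoSaver.on_episode_end` can be exercised without rendering any video. The sketch below pulls that logic into two helper functions. `next_video_index` and `output_fps` are hypothetical names used only for illustration; they are not part of `openpi_client`. The sketch also assumes a 50 Hz control rate, which matches `max_hz=50` in the simulator's `main.py`.

```python
import pathlib
import tempfile


def next_video_index(out_dir: pathlib.Path) -> int:
    # Same rule as VideoSaver.on_episode_end: largest existing out_<N>.mp4 index + 1, or 0 if none exist.
    existing = list(out_dir.glob("out_[0-9]*.mp4"))
    return max((int(p.stem.split("_")[1]) for p in existing), default=-1) + 1


def output_fps(control_hz: int, subsample: int) -> int:
    # Keeping every `subsample`-th frame of a control_hz stream keeps playback at real-time speed.
    return control_hz // max(1, subsample)


with tempfile.TemporaryDirectory() as tmp:
    out_dir = pathlib.Path(tmp)
    first_idx = next_video_index(out_dir)  # empty directory
    for n in (0, 1, 7):
        (out_dir / f"out_{n}.mp4").touch()
    after_gap_idx = next_video_index(out_dir)  # continues after the highest index, not the file count

print(first_idx, after_gap_idx, output_fps(50, 2))
```

Because the next index is `max + 1` rather than the number of files, deleting an earlier video never causes an existing one to be overwritten.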


================================================
FILE: examples/convert_jax_model_to_pytorch.py
================================================
#!/usr/bin/env python3
"""
Load a JAX model and print all parameter keys, with optional conversion to PyTorch.

This script loads a JAX model checkpoint using orbax and can either:
1. Print out all the parameter keys in a hierarchical structure for inspection
2. Convert the JAX model to PyTorch format using our PI0Pytorch model

Usage:
    # Just inspect keys:
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --config_name <config_name> --inspect_only

    # Convert to PyTorch:
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --config_name <config_name> --output_path /path/to/output

Example:
    # pi0_droid
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid --config_name pi0_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch

    # pi0_aloha_sim
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --config_name pi0_aloha_sim --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch

    # pi05_droid
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid --config_name pi05_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch
"""

import json
import os
import pathlib
import shutil
from typing import Literal

from flax.nnx import traversals
import numpy as np
import orbax.checkpoint as ocp
import safetensors
import torch
import tyro

import openpi.models.gemma
import openpi.models.model
import openpi.models.pi0_config
import openpi.models_pytorch.pi0_pytorch
from openpi.training import utils
import openpi.training.config as _config


def slice_paligemma_state_dict(state_dict, config):
    """Convert PaliGemma JAX parameters to PyTorch format."""
    suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""

    # patch embeddings
    jax_key = f"img/embedding/kernel{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)

    jax_key = f"img/embedding/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # positional embeddings
    jax_key = f"img/pos_embedding{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)

    # Pop the stacked vision-encoder parameters; axis 0 indexes the layer (27 layers in the base model).
    encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
    encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
    encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
    encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")

    encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
    encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
    encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
    encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")

    encoderblock_attention_0_key_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
    )
    encoderblock_attention_0_key_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
    )
    encoderblock_attention_0_value_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
    )
    encoderblock_attention_0_value_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
    )
    encoderblock_attention_0_query_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
    )
    encoderblock_attention_0_query_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
    )
    encoderblock_attention_0_out_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
    )
    encoderblock_attention_0_out_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
    )

    for i in range(config.vision_config.num_hidden_layers):
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
        ] = encoderblock_layernorm0_scale[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
        ] = encoderblock_layernorm0_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
        ] = encoderblock_layernorm1_scale[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
        ] = encoderblock_layernorm1_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
        ] = encoderblock_mlp_dense0_kernel[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
        ] = encoderblock_mlp_dense0_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
        ] = encoderblock_mlp_dense1_kernel[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
        ] = encoderblock_mlp_dense1_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
        ] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
        ] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
        ] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
        ] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
        ] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
        ] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
        ] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
        ] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)

    jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()

    jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # multimodal projector
    jax_key = f"img/head/kernel{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()

    jax_key = f"img/head/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # text decoder (gemma)
    jax_key = f"llm/embedder/input_embedding{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # pop the einsum attention + mlp representations
    llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
    llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
    llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")

    llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
    llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")

    llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
    llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")

    for i in range(config.text_config.num_hidden_layers):
        q_proj_weight_reshaped = (
            llm_attention_q_einsum[i]
            .transpose(0, 2, 1)
            .reshape(
                config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
            )
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
            q_proj_weight_reshaped
        )

        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
            k_proj_weight_reshaped
        )
        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
            v_proj_weight_reshaped
        )

        # [num_heads, head_dim, hidden] -> [hidden (out), num_heads * head_dim (in)]
        o_proj_weight_reshaped = (
            llm_attention_attn_vec_einsum[i]
            .transpose(2, 0, 1)
            .reshape(
                config.text_config.hidden_size, config.text_config.num_attention_heads * config.text_config.head_dim
            )
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
            o_proj_weight_reshaped
        )

        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
            gate_proj_weight.transpose()
        )
        up_proj_weight = llm_mlp_gating_einsum[i, 1]
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
            up_proj_weight.transpose()
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
            llm_mlp_linear[i].transpose()
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
            llm_input_layernorm[i]
        )
        state_dict[
            f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
        ] = llm_post_attention_layernorm[i]

    jax_key = f"llm/final_norm/scale{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    expert_dict = {}
    final_state_dict = {}

    # Expert-related keys to extract (including pi05 Dense layer parameters)
    expert_keys = [
        f"llm/final_norm_1/scale{suffix}",
        f"llm/final_norm_1/Dense_0/bias{suffix}",
        f"llm/final_norm_1/Dense_0/kernel{suffix}",
        f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
        f"llm/layers/attn/kv_einsum_1/w{suffix}",
        f"llm/layers/attn/q_einsum_1/w{suffix}",
        f"llm/layers/mlp_1/gating_einsum{suffix}",
        f"llm/layers/mlp_1/linear{suffix}",
        f"llm/layers/pre_attention_norm_1/scale{suffix}",
        f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}",
        f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}",
        f"llm/layers/pre_ffw_norm_1/scale{suffix}",
        f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
        f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
    ]

    for key, value in state_dict.items():
        if key not in expert_keys:
            final_state_dict[key] = torch.from_numpy(value)
        else:
            expert_dict[key] = value

    return final_state_dict, expert_dict


def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05):
    """Convert Gemma JAX parameters to PyTorch format."""
    # Add missing attributes to config if they don't exist
    if not hasattr(config, "vocab_size"):
        config.vocab_size = 257152  # PALIGEMMA_VOCAB_SIZE
    if not hasattr(config, "hidden_size"):
        config.hidden_size = config.width
    if not hasattr(config, "num_hidden_layers"):
        config.num_hidden_layers = config.depth
    if not hasattr(config, "num_attention_heads"):
        config.num_attention_heads = config.num_heads

    suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""

    llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
    llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
    llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")

    llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
    llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")

    # pi05 uses adaptive normalization (Dense layers); regular pi0 uses RMSNorm scale layers.
    # Decide from the model config rather than the checkpoint path, which need not contain "pi05".
    if pi05:
        # Pi05 with adaptive normalization
        llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
        llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
        llm_input_layernorm_kernel = state_dict.pop(
            f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
        )
        llm_post_attention_layernorm_kernel = state_dict.pop(
            f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
        )
    else:
        # Regular pi0 with standard RMSNorm
        llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
        llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")

    for i in range(config.num_hidden_layers):
        q_proj_weight_reshaped = (
            llm_attention_q_einsum[i]
            .transpose(0, 2, 1)
            .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
            q_proj_weight_reshaped
        )

        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
            k_proj_weight_reshaped
        )
        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
            v_proj_weight_reshaped
        )

        o_proj_weight_reshaped = (
            llm_attention_attn_vec_einsum[i]
            .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
            .transpose(1, 0)
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
            o_proj_weight_reshaped
        )

        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
            gate_proj_weight.transpose()
        )
        up_proj_weight = llm_mlp_gating_einsum[i, 1]
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
            up_proj_weight.transpose()
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[
            i
        ].transpose()

        if pi05:
            # Pi05 with adaptive normalization - use Dense layer parameters directly
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
                llm_input_layernorm_bias[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
                llm_post_attention_layernorm_bias[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
                llm_input_layernorm_kernel[i].transpose()
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
                llm_post_attention_layernorm_kernel[i].transpose()
            )
        else:
            # Regular pi0 with standard RMSNorm
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
                llm_input_layernorm[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
                llm_post_attention_layernorm[i]
            )

    # Handle final norm layer
    if pi05:
        # Pi05 with adaptive normalization - use Dense layer parameters directly
        final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}")
        final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}")
        state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias
        state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
    else:
        # Regular pi0 with standard RMSNorm
        state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
            f"llm/final_norm_{num_expert}/scale{suffix}"
        )

        # state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied.

    final_state_dict = {}
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            final_state_dict[key] = torch.from_numpy(value)
        else:
            final_state_dict[key] = value

    return final_state_dict


def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None):
    """Load and process params by restoring via JAX model loader first.
    This respects dtype conversions that occur during model restore.
    """
    # Use repository restore utility to load a pure dict of params (value suffix removed)
    params = openpi.models.model.restore_params(
        f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision
    )

    return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params}


def load_jax_model_and_print_keys(checkpoint_dir: str):
    """
    Load JAX model from checkpoint and print all parameter keys.

    Args:
        checkpoint_dir: Path to the checkpoint directory
    """
    checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir
    # Initialize checkpointer
    checkpointer = ocp.PyTreeCheckpointer()
    metadata = checkpointer.metadata(f"{checkpoint_dir}/params")
    print(utils.array_tree_to_info(metadata))


def convert_pi0_checkpoint(
    checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config
):
    """
    Convert PI0 JAX checkpoint to PyTorch format.

    Args:
        checkpoint_dir: Path to the JAX checkpoint
        precision: Model precision ("float32" or "bfloat16")
        output_path: Path to save the converted PyTorch model
        model_config: Model config
    """
    print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
    print(f"Model config: {model_config}")

    # Break down orbax ckpts by restoring via JAX to respect dtype
    initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")

    # Process projection params
    if model_config.pi05:
        keys = [
            "action_in_proj",
            "action_out_proj",
            "time_mlp_in",
            "time_mlp_out",
        ]
    else:
        keys = [
            "state_proj",
            "action_in_proj",
            "action_out_proj",
            "action_time_mlp_in",
            "action_time_mlp_out",
        ]

    projection_params = {}
    for key in keys:
        kernel_params = initial_params["projection_params"][key]["kernel"]
        bias_params = initial_params["projection_params"][key]["bias"]
        if isinstance(kernel_params, dict):
            weight = kernel_params["value"]
            bias = bias_params["value"]
        else:
            weight = kernel_params
            bias = bias_params

        pytorch_weight_key = f"{key}.weight"
        pytorch_bias_key = f"{key}.bias"

        projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
        projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))

    # All pi0/pi05 models share the same PaliGemma dimensions (SigLIP vision tower + Gemma 2B
    # language model), so the config is hardcoded here rather than read from the checkpoint.
    class PaliGemmaConfig:
        def __init__(self):
            self.vision_config = type(
                "obj",
                (object,),
                {
                    "hidden_size": 1152,
                    "num_hidden_layers": 27,
                    "num_attention_heads": 16,
                    "intermediate_size": 4304,
                    "patch_size": 14,
                    "projection_dim": 2048,
                },
            )()
            self.text_config = type(
                "obj",
                (object,),
                {
                    "hidden_size": 2048,
                    "num_hidden_layers": 18,
                    "num_attention_heads": 8,
                    "head_dim": 256,
                    "intermediate_size": 16384,
                },
            )()

    paligemma_config = PaliGemmaConfig()
    action_expert_config = openpi.models.gemma.get_config("gemma_300m")

    # Process PaliGemma weights
    paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)

    # Process Gemma weights from expert_params
    gemma_params = slice_gemma_state_dict(
        expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05
    )

    # Instantiate model
    pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)

    # Combine all parameters (no prefix needed for our model structure)
    all_params = {**paligemma_params, **gemma_params, **projection_params}

    # Load state dict
    pi0_model.load_state_dict(all_params, strict=False)

    if precision == "float32":
        pi0_model = pi0_model.to(torch.float32)
    elif precision == "bfloat16":
        pi0_model = pi0_model.to(torch.bfloat16)
    else:
        raise ValueError(f"Invalid precision: {precision}")

    # Save the converted model using safetensors
    os.makedirs(output_path, exist_ok=True)

    # Save model weights as SafeTensors using save_model to handle tied weights
    safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))

    # Copy the checkpoint's assets folder (e.g. normalization statistics) if it exists; it sits next to `params/`
    assets_source = pathlib.Path(checkpoint_dir) / "assets"
    if assets_source.exists():
        assets_dest = pathlib.Path(output_path) / "assets"
        if assets_dest.exists():
            shutil.rmtree(assets_dest)
        shutil.copytree(assets_source, assets_dest)

    # Save config as JSON for reference
    config_dict = {
        "action_dim": model_config.action_dim,
        "action_horizon": model_config.action_horizon,
        "paligemma_variant": model_config.paligemma_variant,
        "action_expert_variant": model_config.action_expert_variant,
        "precision": precision,
    }
    with open(os.path.join(output_path, "config.json"), "w") as f:
        json.dump(config_dict, f, indent=2)

    print("Model conversion completed successfully!")
    print(f"Model saved to {output_path}")


def main(
    checkpoint_dir: str,
    config_name: str,
    output_path: str | None = None,
    precision: Literal["float32", "bfloat16"] = "bfloat16",
    *,
    inspect_only: bool = False,
):
    """Load JAX model and optionally convert to PyTorch.

    Args:
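        checkpoint_dir: Path to the JAX checkpoint directory
        config_name: Name of the openpi training config whose model config (must be a Pi0Config) matches the checkpoint
        output_path: Path to save converted PyTorch model (required for conversion)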
        precision: Precision for model conversion
        inspect_only: Only inspect parameter keys, don't convert
    """
    model_config = _config.get_config(config_name).model
    if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
        raise ValueError(f"Config {config_name} is not a Pi0Config")
    if inspect_only:
        load_jax_model_and_print_keys(checkpoint_dir)
    else:
        if not output_path:
            print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.")
            return
        convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)


if __name__ == "__main__":
    tyro.cli(main)
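The weight-layout conversions in `slice_paligemma_state_dict` and `slice_gemma_state_dict` can be sanity-checked with NumPy alone. The sketch below is a minimal check under stated assumptions: toy dimensions, a single KV head (multi-query attention, which the `[i, 0, 0]`/`[i, 1, 0]` indexing implies), and Gemma-style einsum weight shapes inferred from the slicing code. It verifies two things. First, each converted matrix, used as a `torch.nn.Linear`-style weight (`y = x @ W.T`), reproduces the original einsum projection. Second, `transpose(3, 2, 0, 1)` maps a channels-last conv kernel to PyTorch's `[out, in, kh, kw]` layout.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes, assumed for illustration (the real text model uses width 2048, 8 heads, head_dim 256).
num_heads, head_dim, width, seq_len = 2, 4, 6, 3
x = rng.standard_normal((seq_len, width))

# JAX-side einsum weights for one layer, with the shapes implied by the slicing code.
q_w = rng.standard_normal((num_heads, width, head_dim))  # q_einsum/w
kv_w = rng.standard_normal((2, 1, width, head_dim))  # kv_einsum/w: [k|v, num_kv_heads=1, width, head_dim]
o_w = rng.standard_normal((num_heads, head_dim, width))  # attn_vec_einsum/w

# Reference outputs computed the "einsum way".
q_ref = np.einsum("td,ndh->tnh", x, q_w).reshape(seq_len, num_heads * head_dim)
k_ref = np.einsum("td,dh->th", x, kv_w[0, 0])
attn = rng.standard_normal((seq_len, num_heads, head_dim))  # stand-in for the per-head attention output
o_ref = np.einsum("tnh,nhd->td", attn, o_w)

# torch.nn.Linear computes y = x @ W.T with W shaped [out_features, in_features].
q_lin = q_w.transpose(0, 2, 1).reshape(num_heads * head_dim, width)
k_lin = kv_w[0, 0].transpose()
o_lin = o_w.reshape(num_heads * head_dim, width).transpose(1, 0)
# Transposing first and then reshaping yields the same [width, num_heads * head_dim] matrix.
o_lin_alt = o_w.transpose(2, 0, 1).reshape(width, num_heads * head_dim)

q_ok = np.allclose(x @ q_lin.T, q_ref)
k_ok = np.allclose(x @ k_lin.T, k_ref)
o_ok = np.allclose(attn.reshape(seq_len, -1) @ o_lin.T, o_ref)
o_alt_ok = np.array_equal(o_lin, o_lin_alt)

# Patch embedding: a Flax conv kernel is [kh, kw, in, out]; PyTorch Conv2d expects [out, in, kh, kw].
kh = kw = 2
c_in, c_out = 3, 5
kernel_hwio = rng.standard_normal((kh, kw, c_in, c_out))
kernel_oihw = kernel_hwio.transpose(3, 2, 0, 1)
patch_hwc = rng.standard_normal((kh, kw, c_in))  # one image patch, channels-last
patch_chw = patch_hwc.transpose(2, 0, 1)  # the same patch, channels-first
conv_ok = np.allclose(
    np.einsum("hwi,hwio->o", patch_hwc, kernel_hwio),
    np.einsum("ihw,oihw->o", patch_chw, kernel_oihw),
)

print(q_ok, k_ok, o_ok, o_alt_ok, conv_ok)
```

The toy sizes are deliberately unequal (`num_heads * head_dim` = 8, `width` = 6), so a swapped reshape argument shows up as a shape error or a failed comparison. With the real PaliGemma text dimensions, both products are 2048, which can hide such a mistake.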


================================================
FILE: examples/droid/README.md
================================================
# DROID Policies in openpi

We offer instructions for:
- [Running inference for our best $\pi_{0.5}$-DROID policy](./README.md#running-droid-inference)
- [Running inference for other pre-trained DROID policies ($\pi_0$, $\pi_0$-FAST, ...)](./README.md#running-roboarena-baseline-policies)
- [Pre-training *generalist* policies on the *full* DROID dataset](./README_train.md#training-on-droid)
- [Fine-tuning expert $\pi_{0.5}$ on your custom DROID dataset](./README_train.md#fine-tuning-on-custom-droid-datasets)

## Running DROID Inference

This example shows how to run the fine-tuned $\pi_{0.5}$-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). Based on the [public RoboArena benchmark](https://robo-arena.github.io/leaderboard), this is currently our strongest generalist DROID policy. 


### Step 1: Start a policy server

Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.

1. On a machine with a powerful GPU (e.g., an NVIDIA RTX 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
2. Start the OpenPI server via the following command:

```bash
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=gs://openpi-assets/checkpoints/pi05_droid
```

You can also run the equivalent command below:

```bash
uv run scripts/serve_policy.py --env=DROID
```

### Step 2: Run the DROID robot

1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
2. On the control laptop, activate your DROID conda environment.
3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
7. Run the `main.py` file. Make sure that the host and port arguments point to the policy server. (To check that the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also specify which external camera the policy should use (the policy takes only one external camera as input): choose from `"left"` or `"right"`.

```bash
python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left"
```

The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!
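
Conceptually, an action-chunking client sends an observation to the server, receives a chunk of future actions, executes the first few of them open-loop, and then queries again; the notes in `compute_droid_nonidle_ranges.py` state that most DROID action-chunking policies execute the first 8 actions of each chunk. The following is a minimal sketch of that control pattern only, not the actual code in `main.py`. The function `run_chunked_episode`, its callbacks, and the `Action` type are hypothetical.

```python
from typing import Callable, List, Sequence

Action = List[float]  # Hypothetical action type: one vector of floats per control step


def run_chunked_episode(
    get_observation: Callable[[], dict],
    infer_chunk: Callable[[dict], Sequence[Action]],
    execute: Callable[[Action], None],
    max_steps: int,
    open_loop_horizon: int = 8,
) -> int:
    """Execute `max_steps` control steps, re-querying the policy every `open_loop_horizon` steps.

    Returns the number of policy queries that were made.
    """
    chunk: Sequence[Action] = []
    chunk_index = 0
    num_queries = 0
    for _ in range(max_steps):
        # Query a new chunk at the start, after `open_loop_horizon` executed actions,
        # or earlier if the returned chunk is shorter than the horizon.
        if chunk_index >= min(open_loop_horizon, len(chunk)):
            chunk = infer_chunk(get_observation())
            chunk_index = 0
            num_queries += 1
        # Execute the next action of the current chunk open-loop (no new observation used).
        execute(chunk[chunk_index])
        chunk_index += 1
    return num_queries
```

At DROID's 15 Hz recording rate (the `fps` used in the conversion script), an 8-step horizon corresponds to roughly half a second of motion per server query.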

## Troubleshooting

| Issue | Solution |
|-------|----------|
| Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |
| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explorer` in the command line. |
| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (a latency of 0.5 to 1 s per action chunk is normal). |
| Policy does not perform the task well | In our experiments, the policy could perform simple tabletop manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, try modifying the scene or object placement to make the task easier. Also make sure that the camera view you pass to the policy shows all relevant objects in the scene: the policy is conditioned on only a single external camera plus the wrist camera, so check that you are feeding it the desired external camera. Use `ZED_Explorer` to verify what that camera sees. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |
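
`ping` only shows that the server machine is up; it does not check that the policy server is actually listening on `<server_port>`. As an extra check that is not part of openpi, a minimal stdlib-only sketch can try to open a TCP connection to the server from the DROID laptop (`can_connect` is a hypothetical helper name):

```python
import socket


def can_connect(host: str, port: int, timeout_s: float = 2.0) -> bool:
    """Return True if a TCP connection to (host, port) can be opened within `timeout_s` seconds."""
    try:
        # create_connection resolves the host and raises OSError on refusal, unreachability, or timeout.
        with socket.create_connection((host, port), timeout=timeout_s):
            return True
    except OSError:
        return False
```

For example, if `can_connect("<server_ip>", <server_port>)` returns `False` while `ping` succeeds, the policy server is likely not running on that port or a firewall is blocking it.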


## Running Other Policies

We provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies. Then follow the instructions above to run evaluation on the DROID robot.

```
# Trained from pi0-FAST, using FAST tokenizer
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid

# Trained from pi0, using flow matching
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_droid

# Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid

# Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid

# Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid

# Trained from PaliGemma, using FSQ tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid

# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid
```

You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py).
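
If you want to launch these servers from a script, a small sketch that rebuilds the command lines above from a table of config names and checkpoint directories could look like the following. The dictionary contents are copied from the commands above; `BASELINE_CHECKPOINTS` and `serve_command` are hypothetical names, not part of openpi.

```python
from typing import List

# Config names and checkpoint directories copied from the serve commands above.
BASELINE_CHECKPOINTS = {
    "pi0_fast_droid": "gs://openpi-assets/checkpoints/pi0_fast_droid",
    "pi0_droid": "gs://openpi-assets/checkpoints/pi0_droid",
    "paligemma_binning_droid": "gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid",
    "paligemma_fast_droid": "gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid",
    "paligemma_fast_specialist_droid": "gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid",
    "paligemma_vq_droid": "gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid",
    "paligemma_diffusion_droid": "gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid",
}


def serve_command(config_name: str) -> List[str]:
    """Return the argv that serves `config_name` via scripts/serve_policy.py."""
    if config_name not in BASELINE_CHECKPOINTS:
        raise KeyError(f"Unknown config {config_name!r}; expected one of {sorted(BASELINE_CHECKPOINTS)}")
    return [
        "uv",
        "run",
        "scripts/serve_policy.py",
        "policy:checkpoint",
        f"--policy.config={config_name}",
        f"--policy.dir={BASELINE_CHECKPOINTS[config_name]}",
    ]
```

For example, `subprocess.run(serve_command("pi0_droid"), check=True)`, run from the repository root, starts the same server as the corresponding command above.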


================================================
FILE: examples/droid/README_train.md
================================================
# Training on DROID

Here we describe how to fine-tune the pi0.5 model on the *full* DROID dataset. This is an approximate open-source reproduction of the pi05-DROID training pipeline, with small differences in data loading and in the action space used. For a tutorial on how to fine-tune your model on a smaller, custom dataset collected on the DROID platform, see [below](#fine-tuning-on-custom-droid-datasets).

In contrast to the rest of openpi, which uses LeRobot for data loading, full DROID training uses the RLDS data format, since LeRobot currently does not scale well enough to datasets as large as DROID (the LeRobot team is working on improving this). Below, we provide instructions for updating your openpi environment for RLDS data loading and for downloading the DROID dataset.

## Install

We need a few additional dependencies for RLDS data loading. Run:
```bash
uv sync --group rlds
```

## Download DROID dataset

You can download the DROID dataset with the following command (after installing the `gsutil` Google Cloud CLI):
```
gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 <your_download_path>/droid/1.0.1
```

Note that downloading version 1.0.1 (not v1.0.0) is important: it contains the complete set of language annotations (~75k episodes), while v1.0.0 only has annotations for 30k episodes. If for some reason you would like to use another version, modify the line `version="1.0.1"` in the `DroidRldsDataset` object [here](../../src/openpi/training/droid_rlds_dataset.py).

You will need 1.8TB of disk storage to download the DROID RLDS dataset.

## Run

First, change the `rlds_data_dir` path in your `TrainConfig` to the directory that you downloaded the `droid` dataset into (see [src/openpi/training/config.py](../../src/openpi/training/config.py)).

Then, compute normalization statistics (this will take ~10 minutes):
```bash
uv run --group rlds scripts/compute_norm_stats.py --config-name pi05_full_droid_finetune --max-frames 10_000_000
```

Run training:
```bash
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py pi05_full_droid_finetune --exp-name=my_experiment --overwrite
```

**Note**: The original pi0.5-DROID model was trained with joint velocity actions.
Joint velocity actions are much harder to simulate, so they are not compatible with simulated evaluation environments.
Thus, we do not recommend training with joint velocity actions and use joint position actions here instead.


## Compute Requirements

Our DROID training config requires approximately 2 days on 8x H100 GPUs for convergence (100k iterations, batch size 256, approx. 1 epoch).
If you start from PaliGemma instead of pi0 initialization, plan for ~5 days on 8x H100s (240k iterations, i.e., ~3 epochs).

We have experimented with LoRA for cheaper finetuning, but haven't found the policies to perform well so far.


## Data Filtering

Like any diverse real-robot dataset, the DROID dataset isn't perfectly "clean", and we have found data filtering to significantly improve policy performance. Concretely, the DROID dataset contains many *idle* timesteps in which the robot does not move (partly due to the VR teleoperation interface used during data collection; we won't go into detail here). Appropriately filtering out these idle transitions can improve policy performance.

By default, our openpi training recipe implements the same idle filter used to train all pi-DROID models. We implement it by pre-computing which dataset indices to sample during training. You can check [compute_droid_nonidle_ranges.py](compute_droid_nonidle_ranges.py) for how we compute these indices. Roughly speaking, we filter out any time steps for which the next chunk of actions would be largely idle. During training, our code automatically pulls our pre-computed list of indices from cloud storage and applies them. If you want to modify the idle filter or create your own sampling logic, you can modify our script to generate a new index list and provide it via the `filter_dict_path="<path_to_filter_dict>"` argument in [src/openpi/training/config.py](../../src/openpi/training/config.py).

**Note**: our list of filtering indices is only valid for the `droid/1.0.1` dataset mentioned in the download section above and will not provide valid filtering for any other version of the DROID dataset, so make sure you download the dataset above! If you have a custom DROID version, you can rerun the [compute_droid_nonidle_ranges.py](compute_droid_nonidle_ranges.py) script to generate a new list of sampling indices.

## RoboArena

Consider submitting your DROID policies to the [RoboArena benchmark](https://robo-arena.github.io/), which allows you to evaluate your policies on diverse tasks & scenes, **in the real world**! :)

If you have questions about RoboArena, please email [karl.pertsch@gmail.com](mailto:karl.pertsch@gmail.com).


# Fine-Tuning on Custom DROID Datasets

Here we describe how to fine-tune a model on a custom (smaller) dataset collected on the DROID platform. As with other datasets, we will first convert the custom DROID dataset to LeRobot format and then fine-tune a model (pi05-droid) on it.

Note: We use LeRobot here, since we assume the custom DROID fine-tuning dataset to be relatively small (at most a few tens of hours). For larger datasets (like the full DROID dataset), we recommend using RLDS for its better efficiency (see the example above).


## Step 1: Converting your custom DROID dataset to LeRobot

We will use a small subset of the real DROID dataset for this example. This is a subset of just 30 demonstrations -- we assume that you will use your own dataset instead, but here is the command to download our subset (1.6GB):
```
gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04 <your_target_path>
```

We will also download the language annotations for the DROID dataset so we can pair our demonstrations with language instructions. Again, for your own data you can manually enter your language instructions and don't need to download our annotations. To download the DROID language annotations (12MB), run the command below. Download them into the same directory as the demonstrations, since the conversion script looks for `aggregated-annotations-030724.json` directly inside `--data_dir`:
```
gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json <your_target_path>
```
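
The conversion script pairs each episode with an instruction by extracting an episode ID from the episode's `metadata_*.json` filename and looking that ID up in the annotations file, falling back to a placeholder instruction if no annotation exists. A stdlib-only sketch of that lookup (the helper names are hypothetical):

```python
from pathlib import Path


def episode_id_from_metadata(metadata_path: Path) -> str:
    # Same parsing as the conversion script: drop the extension, keep the text after the last "_".
    return metadata_path.name.split(".")[0].split("_")[-1]


def lookup_instruction(annotations: dict, episode_id: str, default: str = "Do something") -> str:
    # Fall back to a placeholder instruction when the episode has no annotation entry.
    return annotations.get(episode_id, {"language_instruction1": default})["language_instruction1"]
```

The `annotations` dictionary would come from loading the downloaded file, e.g. `json.loads(Path("<your_target_path>/aggregated-annotations-030724.json").read_text())`.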

For your own dataset, make sure that each episode's directory contains a folder called `recordings/MP4` -- if not, you need to first run the MP4 video extraction (from SVO files) using the script [here](https://github.com/droid-dataset/droid/blob/main/scripts/convert/svo_to_mp4.py).
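
To find episodes that still need MP4 extraction before converting, a small sketch using the same `**/trajectory.h5` glob as the conversion script can list episode directories that lack a `recordings/MP4` folder (`episodes_missing_mp4` is a hypothetical helper):

```python
from pathlib import Path
from typing import List


def episodes_missing_mp4(data_dir: Path) -> List[Path]:
    """List episode directories (those containing trajectory.h5) that have no recordings/MP4 folder."""
    missing = []
    for traj in sorted(data_dir.glob("**/trajectory.h5")):
        # An episode directory is the parent of its trajectory.h5 file.
        if not (traj.parent / "recordings" / "MP4").is_dir():
            missing.append(traj.parent)
    return missing
```

Any directory returned here would need the SVO-to-MP4 extraction step before running the conversion.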

Now, we will use the `convert_droid_data_to_lerobot.py` script to create a LeRobot version of this dataset (takes <5min for the 30 demonstrations):
```
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir <your_target_path>
```

## Step 2: Run fine-tuning with your custom dataset

Now we can run fine-tuning with our converted custom dataset. We provide an example config for fine-tuning `pi05_droid` on the custom dataset we created.
You can easily modify the config to work with other base models or to use your custom DROID dataset in `config.py` (search for `pi05_droid_finetune`).

To launch training:
```
uv run scripts/train.py pi05_droid_finetune --exp-name=my_experiment --overwrite
```

Once trained, you can follow the instructions in [`examples/droid/README.md`](README.md) to serve the policy and run it on the robot.



================================================
FILE: examples/droid/compute_droid_nonidle_ranges.py
================================================
"""
Iterates through the DROID dataset and creates a json mapping from episode unique IDs to ranges of time steps
that should be sampled during training (all others are filtered out).

Filtering logic:
We look for ranges of consecutive steps that contain fewer than min_idle_len consecutive idle frames
(default: 7 -- as most DROID action-chunking policies run the first 8 actions generated in each chunk, filtering
this way means the policy will not get stuck outputting stationary actions). Additionally, we only keep non-idle
ranges of length at least min_non_idle_len (default: 16 frames, ~1 second), while also removing the last
filter_last_n_in_ranges frames from the end of each range (as those all correspond to action chunks with many idle actions).

This leaves us with trajectory segments consisting of contiguous, significant movement. Training on this filtered set
yields policies that output fewer stationary actions (i.e., get "stuck" in states less).
"""

import json
import os
from pathlib import Path

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = ""  # Set to the GPU you want to use, or leave empty for CPU

builder = tfds.builder_from_directory(
    # path to the `droid` directory (not its parent)
    builder_dir="<path_to_droid_dataset_tfds_files>",
)
ds = builder.as_dataset(split="train", shuffle_files=False)
ds = ds.apply(tf.data.experimental.ignore_errors())  # Skip elements that raise errors while loading

keep_ranges_path = "<path_to_where_to_save_the_json>"

min_idle_len = 7  # If at least this many consecutive idle frames, filter all of them out
min_non_idle_len = 16  # If fewer than this number of consecutive non-idle frames, filter all of them out
filter_last_n_in_ranges = 10  # When using a filter dict, remove this many frames from the end of each range

keep_ranges_map = {}
if Path(keep_ranges_path).exists():
    with Path(keep_ranges_path).open("r") as f:
        keep_ranges_map = json.load(f)
    print(f"Resuming from {len(keep_ranges_map)} episodes already processed")

for ep_idx, ep in enumerate(tqdm(ds)):
    recording_folderpath = ep["episode_metadata"]["recording_folderpath"].numpy().decode()
    file_path = ep["episode_metadata"]["file_path"].numpy().decode()

    key = f"{recording_folderpath}--{file_path}"
    if key in keep_ranges_map:
        continue

    joint_velocities = [step["action_dict"]["joint_velocity"].numpy() for step in ep["steps"]]
    joint_velocities = np.array(joint_velocities)

    is_idle_array = np.hstack(
        [np.array([False]), np.all(np.abs(joint_velocities[1:] - joint_velocities[:-1]) < 1e-3, axis=1)]
    )

    # Find what steps go from idle to non-idle and vice-versa
    is_idle_padded = np.concatenate(
        [[False], is_idle_array, [False]]
    )  # Pad with False on both sides so idle runs at the very start or end of the episode are still detected

    is_idle_diff = np.diff(is_idle_padded.astype(int))
    is_idle_true_starts = np.where(is_idle_diff == 1)[0]  # +1 transitions --> going from non-idle to idle (start of an idle run)
    is_idle_true_ends = np.where(is_idle_diff == -1)[0]  # -1 transitions --> going from idle to non-idle (end of an idle run)

    # Find which steps correspond to idle segments of length at least min_idle_len
    true_segment_masks = (is_idle_true_ends - is_idle_true_starts) >= min_idle_len
    is_idle_true_starts = is_idle_true_starts[true_segment_masks]
    is_idle_true_ends = is_idle_true_ends[true_segment_masks]

    keep_mask = np.ones(len(joint_velocities), dtype=bool)
    for start, end in zip(is_idle_true_starts, is_idle_true_ends, strict=True):
        keep_mask[start:end] = False

    # Get all non-idle ranges of at least min_non_idle_len frames
    # Same logic as above, but for keep_mask, allowing us to filter out contiguous ranges of length < min_non_idle_len
    keep_padded = np.concatenate([[False], keep_mask, [False]])

    keep_diff = np.diff(keep_padded.astype(int))
    keep_true_starts = np.where(keep_diff == 1)[0]  # +1 transitions --> going from filter out to keep
    keep_true_ends = np.where(keep_diff == -1)[0]  # -1 transitions --> going from keep to filter out

    # Find which steps correspond to non-idle segments of length at least min_non_idle_len
    true_segment_masks = (keep_true_ends - keep_true_starts) >= min_non_idle_len
    keep_true_starts = keep_true_starts[true_segment_masks]
    keep_true_ends = keep_true_ends[true_segment_masks]

    # Add mapping from episode unique ID key to list of non-idle ranges to keep
    keep_ranges_map[key] = []
    for start, end in zip(keep_true_starts, keep_true_ends, strict=True):
        keep_ranges_map[key].append((int(start), int(end) - filter_last_n_in_ranges))

    if ep_idx % 1000 == 0:
        with Path(keep_ranges_path).open("w") as f:
            json.dump(keep_ranges_map, f)

print("Done!")
with Path(keep_ranges_path).open("w") as f:
    json.dump(keep_ranges_map, f)


================================================
FILE: examples/droid/convert_droid_data_to_lerobot.py
================================================
"""
Minimal example script for converting a dataset collected on the DROID platform to LeRobot format.

Usage:
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data

If you want to push your dataset to the Hugging Face Hub, you can use the following command:
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub

The resulting dataset will be saved under the $HF_LEROBOT_HOME directory.
"""

from collections import defaultdict
import copy
import glob
import json
from pathlib import Path
import shutil

import cv2
import h5py
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
from PIL import Image
from tqdm import tqdm
import tyro

REPO_NAME = "your_hf_username/my_droid_dataset"  # Name of the output dataset, also used for the Hugging Face Hub


def resize_image(image, size):
    image = Image.fromarray(image)
    return np.array(image.resize(size, resample=Image.BICUBIC))


def main(data_dir: str, *, push_to_hub: bool = False):
    # Clean up any existing dataset in the output directory
    output_path = HF_LEROBOT_HOME / REPO_NAME
    if output_path.exists():
        shutil.rmtree(output_path)
    data_dir = Path(data_dir)

    # Create LeRobot dataset, define features to store
    # We will follow the DROID data naming conventions here.
    # LeRobot assumes that dtype of image data is `image`
    dataset = LeRobotDataset.create(
        repo_id=REPO_NAME,
        robot_type="panda",
        fps=15,  # DROID data is typically recorded at 15fps
        features={
            # We call this "left" since we will only use the left stereo camera (following DROID RLDS convention)
            "exterior_image_1_left": {
                "dtype": "image",
                "shape": (180, 320, 3),  # This is the resolution used in the DROID RLDS dataset
                "names": ["height", "width", "channel"],
            },
            "exterior_image_2_left": {
                "dtype": "image",
                "shape": (180, 320, 3),
                "names": ["height", "width", "channel"],
            },
            "wrist_image_left": {
                "dtype": "image",
                "shape": (180, 320, 3),
                "names": ["height", "width", "channel"],
            },
            "joint_position": {
                "dtype": "float32",
                "shape": (7,),
                "names": ["joint_position"],
            },
            "gripper_position": {
                "dtype": "float32",
                "shape": (1,),
                "names": ["gripper_position"],
            },
            "actions": {
                "dtype": "float32",
                "shape": (8,),  # We will use joint *velocity* actions here (7D) + gripper position (1D)
                "names": ["actions"],
            },
        },
        image_writer_threads=10,
        image_writer_processes=5,
    )

    # Load language annotations
    # Note: we load the DROID language annotations for this example, but you can manually define them for your own data
    with (data_dir / "aggregated-annotations-030724.json").open() as f:
        language_annotations = json.load(f)

    # Loop over raw DROID fine-tuning datasets and write episodes to the LeRobot dataset
    # We assume the following directory structure:
    # RAW_DROID_PATH/
    #   - <...>/
    #     - recordings/
    #        - MP4/
    #          - <camera_id>.mp4  # single-view video of left stereo pair camera
    #     - trajectory.h5
    #   - <...>/
    episode_paths = list(data_dir.glob("**/trajectory.h5"))
    print(f"Found {len(episode_paths)} episodes for conversion")

    # We will loop over each episode and write it to the LeRobot dataset
    for episode_path in tqdm(episode_paths, desc="Converting episodes"):
        # Load raw data
        recording_folderpath = episode_path.parent / "recordings" / "MP4"
        trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))

        # To load the language instruction, we need to parse out the episode_id from the metadata file
        # Again, you can modify this step for your own data, to load your own language instructions
        metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json")))
        episode_id = metadata_filepath.name.split(".")[0].split("_")[-1]
        language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[
            "language_instruction1"
        ]
        print(f"Converting episode with language instruction: {language_instruction}")

        # Write to LeRobot dataset
        for step in trajectory:
            camera_type_dict = step["observation"]["camera_type"]
            wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]
            exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]
            dataset.add_frame(
                {
                    # Note: need to flip BGR --> RGB for loaded images
                    "exterior_image_1_left": resize_image(
                        step["observation"]["image"][exterior_ids[0]][..., ::-1], (320, 180)
                    ),
                    "exterior_image_2_left": resize_image(
                        step["observation"]["image"][exterior_ids[1]][..., ::-1], (320, 180)
                    ),
                    "wrist_image_left": resize_image(step["observation"]["image"][wrist_ids[0]][..., ::-1], (320, 180)),
                    "joint_position": np.asarray(
                        step["observation"]["robot_state"]["joint_positions"], dtype=np.float32
                    ),
                    "gripper_position": np.asarray(
                        step["observation"]["robot_state"]["gripper_position"][None], dtype=np.float32
                    ),
                    # Important: we use joint velocity actions here since pi05-droid was pre-trained on joint velocity actions
                    "actions": np.concatenate(
                        [step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32
                    ),
                    "task": language_instruction,
                }
            )
        dataset.save_episode()

    # Optionally push to the Hugging Face Hub
    if push_to_hub:
        dataset.push_to_hub(
            tags=["droid", "panda"],
            private=False,
            push_videos=True,
            license="apache-2.0",
        )


##########################################################################################################
################ The rest of this file are functions to parse the raw DROID data #########################
################ You don't need to worry about understanding this part           #########################
################ It was copied from here: https://github.com/JonathanYang0127/r2d2_rlds_dataset_builder/blob/parallel_convert/r2_d2/r2_d2.py
##########################################################################################################


camera_type_dict = {
    "hand_camera_id": 0,
    "varied_camera_1_id": 1,
    "varied_camera_2_id": 1,
}

camera_type_to_string_dict = {
    0: "hand_camera",
    1: "varied_camera",
    2: "fixed_camera",
}


def get_camera_type(cam_id):
    if cam_id not in camera_type_dict:
        return None
    type_int = camera_type_dict[cam_id]
    return camera_type_to_string_dict[type_int]


class MP4Reader:
    def __init__(self, filepath, serial_number):
        # Save Parameters #
        self.serial_number = serial_number
        self._index = 0

        # Open Video Reader #
        self._mp4_reader = cv2.VideoCapture(filepath)
        if not self._mp4_reader.isOpened():
            raise RuntimeError("Corrupted MP4 File")

    def set_reading_parameters(
        self,
        image=True,  # noqa: FBT002
        concatenate_images=False,  # noqa: FBT002
        resolution=(0, 0),
        resize_func=None,
    ):
        # Save Parameters #
        self.image = image
        self.concatenate_images = concatenate_images
        self.resolution = resolution
        self.resize_func = cv2.resize
        self.skip_reading = not image
        if self.skip_reading:
            return

    def get_frame_resolution(self):
        # cv2.cv.CV_CAP_PROP_* only exists in OpenCV 2.x; use the cv2.CAP_PROP_* constants instead
        width = self._mp4_reader.get(cv2.CAP_PROP_FRAME_WIDTH)
        height = self._mp4_reader.get(cv2.CAP_PROP_FRAME_HEIGHT)
        return (width, height)

    def get_frame_count(self):
        if self.skip_reading:
            return 0
        return int(self._mp4_reader.get(cv2.CAP_PROP_FRAME_COUNT))

    def set_frame_index(self, index):
        if self.skip_reading:
            return

        if index < self._index:
            self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
            self._index = index

        while self._index < index:
            self.read_camera(ignore_data=True)

    def _process_frame(self, frame):
        frame = copy.deepcopy(frame)
        if self.resolution == (0, 0):
            return frame
        return self.resize_func(frame, self.resolution)

    def read_camera(self, ignore_data=False, correct_timestamp=None):  # noqa: FBT002
        # Skip if Read Unnecessary #
        if self.skip_reading:
            return {}

        # Read Camera #
        success, frame = self._mp4_reader.read()

        self._index += 1
        if not success:
            return None
        if ignore_data:
            return None

        # Return Data #
        data_dict = {}

        if self.concatenate_images or "stereo" not in self.serial_number:
            data_dict["image"] = {self.serial_number: self._process_frame(frame)}
        else:
            single_width = frame.shape[1] // 2
            data_dict["image"] = {
                self.serial_number + "_left": self._process_frame(frame[:, :single_width, :]),
                self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]),
            }

        return data_dict

    def disable_camera(self):
        if hasattr(self, "_mp4_reader"):
            self._mp4_reader.release()


class RecordedMultiCameraWrapper:
    def __init__(self, recording_folderpath, camera_kwargs={}):  # noqa: B006
        # Save Camera Info #
        self.camera_kwargs = camera_kwargs

        # Open Camera Readers #
        mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
        all_filepaths = mp4_filepaths

        self.camera_dict = {}
        for f in all_filepaths:
            serial_number = f.split("/")[-1][:-4]
            cam_type = get_camera_type(serial_number)
            camera_kwargs.get(cam_type, {})

            if f.endswith(".mp4"):
                Reader = MP4Reader  # noqa: N806
            else:
                raise ValueError

            self.camera_dict[serial_number] = Reader(f, serial_number)

    def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}):  # noqa: B006
        full_obs_dict = defaultdict(dict)

        # Read Cameras In Randomized Order #
        all_cam_ids = list(self.camera_dict.keys())
        # random.shuffle(all_cam_ids)

        for cam_id in all_cam_ids:
            if "stereo" in cam_id:
                continue
            try:
                cam_type = camera_type_dict[cam_id]
            except KeyError:
                print(f"{self.camera_dict} -- {camera_type_dict}")
                raise ValueError(f"Camera type {cam_id} not found in camera_type_dict")  # noqa: B904
            curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})
            self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs)

            timestamp = timestamp_dict.get(cam_id + "_frame_received", None)
            if index is not None:
                self.camera_dict[cam_id].set_frame_index(index)

            data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp)

            # Process Returned Data #
            if data_dict is None:
                return None
            for key in data_dict:
                full_obs_dict[key].update(data_dict[key])

        return full_obs_dict


def get_hdf5_length(hdf5_file, keys_to_ignore=[]):  # noqa: B006
    length = None

    for key in hdf5_file:
        if key in keys_to_ignore:
            continue

        curr_data = hdf5_file[key]
        if isinstance(curr_data, h5py.Group):
            curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore)
        elif isinstance(curr_data, h5py.Dataset):
            curr_length = len(curr_data)
        else:
            raise ValueError

        if length is None:
            length = curr_length
        assert curr_length == length

    return length


def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]):  # noqa: B006
    data_dict = {}

    for key in hdf5_file:
        if key in keys_to_ignore:
            continue

        curr_data = hdf5_file[key]
        if isinstance(curr_data, h5py.Group):
            data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore)
        elif isinstance(curr_data, h5py.Dataset):
            data_dict[key] = curr_data[index]
        else:
            raise ValueError

    return data_dict


class TrajectoryReader:
    def __init__(self, filepath, read_images=True):  # noqa: FBT002
        self._hdf5_file = h5py.File(filepath, "r")
        is_video_folder = "observations/videos" in self._hdf5_file
        self._read_images = read_images and is_video_folder
        self._length = get_hdf5_length(self._hdf5_file)
        self._video_readers = {}
        self._index = 0

    def length(self):
        return self._length

    def read_timestep(self, index=None, keys_to_ignore=[]):  # noqa: B006
        # Make Sure We Read Within Range #
        if index is None:
            index = self._index
        else:
            assert not self._read_images
            self._index = index
        assert index < self._length

        # Load Low Dimensional Data #
        keys_to_ignore = [*keys_to_ignore, "videos"]
        timestep = load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore)

        # Increment Read Index #
        self._index += 1

        # Return Timestep #
        return timestep

    def close(self):
        self._hdf5_file.close()
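

# Illustrative sketch (hypothetical class, not part of the original script) of
# the cursor behavior in TrajectoryReader.read_timestep:
# - a call without an index reads at the internal cursor;
# - a call with an index moves the cursor to that index first;
# - in both cases the cursor then advances by one.
# A plain list of per-step records stands in for the HDF5 file.
class _SketchCursorReader:
    def __init__(self, steps):
        self._steps = steps
        self._index = 0

    def read_timestep(self, index=None):
        if index is not None:
            self._index = index
        if not 0 <= self._index < len(self._steps):
            raise IndexError(f"timestep {self._index} out of range [0, {len(self._steps)})")
        step = self._steps[self._index]
        self._index += 1
        return step


# Usage: after an explicit read at index 2, a plain read_timestep() continues at 3.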


def load_trajectory(
    filepath=None,
    read_cameras=True,  # noqa: FBT002
    recording_folderpath=None,
    camera_kwargs={},  # noqa: B006
    remove_skipped_steps=False,  # noqa: FBT002
    num_samples_per_traj=None,
    num_samples_per_traj_coeff=1.5,
):
    read_recording_folderpath = read_cameras and (recording_folderpath is not None)

    traj_reader = TrajectoryReader(filepath)
    if read_recording_folderpath:
        camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)

    horizon = traj_reader.length()
    timestep_list = []

    # Choose Timesteps To Save #
    if num_samples_per_traj:
        num_to_save = num_samples_per_traj
        if remove_skipped_steps:
            num_to_save = int(num_to_save * num_samples_per_traj_coeff)
        max_size = min(num_to_save, horizon)
        indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False))
    else:
        indices_to_save = np.arange(horizon)

    # Iterate Over Trajectory #
    for i in indices_to_save:
        # Get HDF5 Data #
        timestep = traj_reader.read_timestep(index=i)

        # If Applicable, Get Recorded Data #
        if read_recording_folderpath:
            timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
            camera_type_dict = {
                k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items()
            }
            camera_obs = camera_reader.read_cameras(
                index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict
            )
            camera_failed = camera_obs is None

            # Add Data To Timestep If Successful #
            if camera_failed:
                break
            timestep["observation"].update(camera_obs)

        # Filter Steps #
        step_skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True)
        delete_skipped_step = step_skipped and remove_skipped_steps

        # Save Filtered Timesteps #
        if delete_skipped_step:
            del timestep
        else:
            timestep_list.append(timestep)

    # Remove Extra Transitions #
    timestep_list = np.array(timestep_list)
    if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj):
        ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False)
        timestep_list = timestep_list[ind_to_keep]

    # Close Readers #
    traj_reader.close()

    # Return Data #
    return timestep_list
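

# Illustrative sketch (hypothetical helper, not part of the original script) of
# the sampling strategy in load_trajectory. When skipped (movement-disabled)
# steps will be filtered out, it oversamples by `coeff` so that enough steps
# survive, then trims back to `num_samples`.
# `keep` is a stand-in predicate for the movement_enabled check. Passing
# keep=lambda s: True and coeff=1.0 reproduces the no-filtering case.
# Design difference: the original trims with an unsorted np.random.choice,
# which loses temporal order. This version sorts the kept indices, so the
# returned steps stay in chronological order.
import numpy as np


def _sketch_sample_steps(steps, num_samples, keep, coeff=1.5, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    # Oversample, but never draw more steps than the trajectory has
    num_to_draw = min(int(num_samples * coeff), len(steps))
    drawn = np.sort(rng.choice(len(steps), size=num_to_draw, replace=False))
    kept = [steps[i] for i in drawn if keep(steps[i])]
    if len(kept) > num_samples:
        # Trim back to the requested count while preserving order
        order = np.sort(rng.choice(len(kept), size=num_samples, replace=False))
        kept = [kept[i] for i in order]
    return kept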


if __name__ == "__main__":
    tyro.cli(main)


================================================
FILE: examples/droid/main.py
================================================
# ruff: noqa

import contextlib
import dataclasses
import datetime
import faulthandler
import os
import signal
import time
from moviepy.editor import ImageSequenceClip
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy
import pandas as pd
from PIL import Image
from droid.robot_env import RobotEnv
import tqdm
import tyro

faulthandler.enable()

# DROID data collection frequency -- we slow down execution to match this frequency
DROID_CONTROL_FREQUENCY = 15


@dataclasses.dataclass
class Args:
    # Hardware parameters
    left_camera_id: str = "<your_camera_id>"  # e.g., "24259877"
    right_camera_id: str = "<your_camera_id>"  # e.g., "24514023"
    wrist_camera_id: str = "<your_camera_id>"  # e.g., "13062452"

    # Policy parameters
    external_camera: str | None = (
        None  # which external camera should be fed to the policy, choose from ["left", "right"]
    )

    # Rollout parameters
    max_timesteps: int = 600
    # How many actions to execute from a predicted action chunk before querying the policy server again.
    # 8 is usually a good default (about 0.5 seconds of execution at 15 Hz); it must not exceed the chunk length of 10.
    open_loop_horizon: int = 8

    # Remote server parameters
    remote_host: str = "0.0.0.0"  # point this to the IP address of the policy server, e.g., "192.168.1.100"
    remote_port: int = (
        8000  # point this to the port of the policy server, default server port for openpi servers is 8000
    )
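

# Illustrative sketch (hypothetical helper, not part of the original script) of
# the open-loop schedule used in main(). The policy is queried on the first step
# and then every `open_loop_horizon` steps. Each chunk is consumed at one action
# per control step. Because main() indexes into a chunk of 10 actions (the
# chunk_length default below is taken from the (10, 8) shape asserted there),
# open_loop_horizon must lie in [1, 10].
def _sketch_query_steps(max_timesteps, open_loop_horizon, chunk_length=10):
    if not 1 <= open_loop_horizon <= chunk_length:
        raise ValueError(f"open_loop_horizon must be in [1, {chunk_length}], got {open_loop_horizon}")
    query_steps = []
    actions_from_chunk_completed = 0
    for t_step in range(max_timesteps):
        # Same condition as in main(): query at the start or when the horizon is used up
        if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= open_loop_horizon:
            actions_from_chunk_completed = 0
            query_steps.append(t_step)
        actions_from_chunk_completed += 1
    return query_steps


# Usage: with the defaults (600 steps, horizon 8), the server is queried 75 times.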


# We use Ctrl+C to optionally terminate rollouts early. However, pressing Ctrl+C while the client is waiting for a new
# action chunk from the policy server raises an exception that kills the server connection.
# This context manager temporarily intercepts Ctrl+C and defers it until the server call is complete.
@contextlib.contextmanager
def prevent_keyboard_interrupt():
    """Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
    interrupted = False
    original_handler = signal.getsignal(signal.SIGINT)

    def handler(signum, frame):
        nonlocal interrupted
        interrupted = True

    signal.signal(signal.SIGINT, handler)
    try:
        yield
    finally:
        signal.signal(signal.SIGINT, original_handler)
        if interrupted:
            raise KeyboardInterrupt
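

# Illustrative sketch (hypothetical variant, not part of the original script):
# signal.signal() may only be called from the main thread. The context manager
# above would therefore raise ValueError if the rollout loop were moved to a
# worker thread. This variant falls back to a no-op in that case; otherwise it
# defers SIGINT exactly like prevent_keyboard_interrupt.
import contextlib
import signal
import threading


@contextlib.contextmanager
def _sketch_defer_sigint():
    if threading.current_thread() is not threading.main_thread():
        yield  # cannot install a handler here, so interrupts are not deferred
        return
    interrupted = False
    original_handler = signal.getsignal(signal.SIGINT)

    def handler(signum, frame):
        # Record the interrupt instead of raising immediately
        nonlocal interrupted
        interrupted = True

    signal.signal(signal.SIGINT, handler)
    try:
        yield
    finally:
        # Restore the previous handler, then re-raise the deferred interrupt
        signal.signal(signal.SIGINT, original_handler)
        if interrupted:
            raise KeyboardInterrupt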


def main(args: Args):
    # Make sure external camera is specified by user -- we only use one external camera for the policy
    assert (
        args.external_camera is not None and args.external_camera in ["left", "right"]
    ), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"

    # Initialize the Panda environment. The joint velocity and gripper position action spaces are required, since the
    # policy outputs joint velocities and a gripper position.
    env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
    print("Created the droid env!")

    # Connect to the policy server
    policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)

    df = pd.DataFrame(columns=["success", "duration", "video_filename"])

    while True:
        instruction = input("Enter instruction: ")

        # Rollout parameters
        actions_from_chunk_completed = 0
        pred_action_chunk = None

        # Prepare to save video of rollout
        timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
        video = []
        bar = tqdm.tqdm(range(args.max_timesteps))
        print("Running rollout... press Ctrl+C to stop early.")
        for t_step in bar:
            start_time = time.time()
            try:
                # Get the current observation
                curr_obs = _extract_observation(
                    args,
                    env.get_observation(),
                    # Save the first observation to disk
                    save_to_disk=t_step == 0,
                )

                video.append(curr_obs[f"{args.external_camera}_image"])

                # Send websocket request to policy server if it's time to predict a new chunk
                if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
                    actions_from_chunk_completed = 0

                    # We resize images on the robot laptop to minimize the amount of data sent to the policy server
                    # and improve latency.
                    request_data = {
                        "observation/exterior_image_1_left": image_tools.resize_with_pad(
                            curr_obs[f"{args.external_camera}_image"], 224, 224
                        ),
                        "observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
                        "observation/joint_position": curr_obs["joint_position"],
                        "observation/gripper_position": curr_obs["gripper_position"],
                        "prompt": instruction,
                    }

                    # Wrap the server call in a context manager to prevent Ctrl+C from interrupting it
                    # Ctrl+C will be handled after the server call is complete
                    with prevent_keyboard_interrupt():
                        # Returns an action chunk of shape (10, 8): 10 timesteps, each with 7 joint velocities + 1 gripper position
                        pred_action_chunk = policy_client.infer(request_data)["actions"]
                    assert pred_action_chunk.shape == (10, 8)

                # Select current action to execute from chunk
                action = pred_action_chunk[actions_from_chunk_completed]
                actions_from_chunk_completed += 1

                # Binarize gripper action
                if action[-1].item() > 0.5:
                    action = np.concatenate([action[:-1], np.ones((1,))])
                else:
                    action = np.concatenate([action[:-1], np.zeros((1,))])

                # clip all dimensions of action to [-1, 1]
                action = np.clip(action, -1, 1)

                env.step(action)

                # Sleep to match DROID data collection frequency
                elapsed_time = time.time() - start_time
                if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
                    time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
            except KeyboardInterrupt:
                break

        video = np.stack(video)
        save_filename = "video_" + timestamp
        ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")

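        success: float | None = None
        while success is None:
            response = input(
                "Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec: "
            ).strip().lower()
            # Map y/n onto the same 0-100 scale as numeric answers before normalizing to [0, 1]
            if response == "y":
                percent = 100.0
            elif response == "n":
                percent = 0.0
            else:
                try:
                    percent = float(response)
                except ValueError:
                    print(f"Could not parse {response!r}; enter y, n, or a number in [0, 100]")
                    continue
            if not (0 <= percent <= 100):
                print(f"Success must be a number in [0, 100] but got: {percent}")
                continue
            success = percent / 100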

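        # DataFrame.append was removed in pandas 2.0; pd.concat works across pandas versions.
        df = pd.concat(
            [df, pd.DataFrame([{"success": success, "duration": t_step, "video_filename": save_filename}])],
            ignore_index=True,
        )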

        if input("Do one more eval? (enter y or n) ").lower() != "y":
            break
        env.reset()

    os.makedirs("results", exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
    csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
    df.to_csv(csv_filename)
    print(f"Results saved to {csv_filename}")
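

# Illustrative sketch (hypothetical helper, not part of the original script) of
# the action post-processing in main(). The last dimension (gripper position) is
# thresholded at 0.5 to {0, 1}, then every dimension is clipped to [-1, 1].
# np.append replaces the if/else concatenation; the behavior is the same.
import numpy as np


def _sketch_postprocess_action(action, gripper_threshold=0.5):
    action = np.asarray(action, dtype=np.float64)
    gripper = 1.0 if action[-1] > gripper_threshold else 0.0
    return np.clip(np.append(action[:-1], gripper), -1.0, 1.0)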


def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
    image_observations = obs_dict["image"]
    left_image, right_image, wrist_image = None, None, None
    for key in image_observations:
        # Note the "left" below refers to the left camera in the stereo pair.
        # The model is only trained on left stereo cams, so we only feed those.
        if args.left_camera_id in key and "left" in key:
            left_image = image_observations[key]
        elif args.right_camera_id in key and "left" in key:
            right_image = image_observations[key]
        elif args.wrist_camera_id in key and "left" in key:
            wrist_image = image_observations[key]

    # Drop the alpha channel
    left_image = left_image[..., :3]
    right_image = right_image[..., :3]
    wrist_image = wrist_image[..., :3]

    # Convert BGR to RGB
    left_image = left_image[..., ::-1]
    right_image = right_image[..., ::-1]
    wrist_image = wrist_image[..., ::-1]

    # In addition to image observations, also capture the proprioceptive state
    robot_state = obs_dict["robot_state"]
    cartesian_position = np.array(robot_state["cartesian_position"])
    joint_position = np.array(robot_state["joint_positions"])
    gripper_position = np.array([robot_state["gripper_position"]])

    # Save the images to disk so that they can be viewed live while the robot is running
    # Create one combined image to make live viewing easy
    if save_to_disk:
        combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
        combined_image = Image.fromarray(combined_image)
        combined_image.save("robot_camera_views.png")

    return {
        "left_image": left_image,
        "right_image": right_image,
        "wrist_image": wrist_image,
        "cartesian_position": cartesian_position,
        "joint_position": joint_position,
        "gripper_position": gripper_position,
    }
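

# Illustrative sketch (hypothetical helpers, not part of the original script) of
# the two steps in _extract_observation:
# 1. Select the left stereo image for a camera. This follows the same matching
#    rule as above: the key contains the camera id and "left". The keys in the
#    usage below are made up. A missing camera raises a clear KeyError instead
#    of failing later on None.
# 2. Convert a 4-channel BGRA frame to RGB by dropping alpha and reversing
#    the channel order.
import numpy as np


def _sketch_find_left_stereo(image_observations, camera_id):
    for key, image in image_observations.items():
        if camera_id in key and "left" in key:
            return image
    raise KeyError(f"no left stereo image for camera {camera_id!r} in {sorted(image_observations)}")


def _sketch_bgra_to_rgb(image):
    # Keep channels 0..2 (B, G, R), then reverse them to (R, G, B)
    return image[..., :3][..., ::-1]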


if __name__ == "__main__":
    args: Args = tyro.cli(Args)
    main(args)


================================================
FILE: examples/inference.ipynb
SYMBOL INDEX (837 symbols across 84 files)

FILE: examples/aloha_real/convert_aloha_data_to_lerobot.py
  class DatasetConfig (line 23) | class DatasetConfig:
  function create_empty_dataset (line 34) | def create_empty_dataset(
  function get_cameras (line 128) | def get_cameras(hdf5_files: list[Path]) -> list[str]:
  function has_velocity (line 134) | def has_velocity(hdf5_files: list[Path]) -> bool:
  function has_effort (line 139) | def has_effort(hdf5_files: list[Path]) -> bool:
  function load_raw_images_per_camera (line 144) | def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dic...
  function load_raw_episode_data (line 165) | def load_raw_episode_data(
  function populate_dataset (line 193) | def populate_dataset(
  function port_aloha (line 229) | def port_aloha(

FILE: examples/aloha_real/env.py
  class AlohaRealEnvironment (line 11) | class AlohaRealEnvironment(_environment.Environment):
    method __init__ (line 14) | def __init__(
    method reset (line 27) | def reset(self) -> None:
    method is_episode_complete (line 31) | def is_episode_complete(self) -> bool:
    method get_observation (line 35) | def get_observation(self) -> dict:
    method apply_action (line 56) | def apply_action(self, action: dict) -> None:

FILE: examples/aloha_real/main.py
  class Args (line 14) | class Args:
  function main (line 24) | def main(args: Args) -> None:

FILE: examples/aloha_real/real_env.py
  class RealEnv (line 18) | class RealEnv:
    method __init__ (line 40) | def __init__(self, init_node, *, reset_position: Optional[List[float]]...
    method setup_robots (line 62) | def setup_robots(self):
    method get_qpos (line 66) | def get_qpos(self):
    method get_qvel (line 79) | def get_qvel(self):
    method get_effort (line 88) | def get_effort(self):
    method get_images (line 95) | def get_images(self):
    method set_gripper_pose (line 98) | def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_...
    method _reset_joints (line 109) | def _reset_joints(self):
    method _reset_gripper (line 114) | def _reset_gripper(self):
    method get_observation (line 128) | def get_observation(self):
    method get_reward (line 136) | def get_reward(self):
    method reset (line 139) | def reset(self, *, fake=False):
    method step (line 150) | def step(self, action):
  function get_action (line 163) | def get_action(master_bot_left, master_bot_right):
  function make_real_env (line 175) | def make_real_env(init_node, *, reset_position: Optional[List[float]] = ...

FILE: examples/aloha_real/robot_utils.py
  class ImageRecorder (line 19) | class ImageRecorder:
    method __init__ (line 20) | def __init__(self, init_node=True, is_debug=False):
    method image_cb (line 48) | def image_cb(self, cam_name, data):
    method image_cb_cam_high (line 72) | def image_cb_cam_high(self, data):
    method image_cb_cam_low (line 76) | def image_cb_cam_low(self, data):
    method image_cb_cam_left_wrist (line 80) | def image_cb_cam_left_wrist(self, data):
    method image_cb_cam_right_wrist (line 84) | def image_cb_cam_right_wrist(self, data):
    method get_images (line 88) | def get_images(self):
    method print_diagnostics (line 100) | def print_diagnostics(self):
  class Recorder (line 112) | class Recorder:
    method __init__ (line 113) | def __init__(self, side, init_node=True, is_debug=False):
    method puppet_state_cb (line 141) | def puppet_state_cb(self, data):
    method puppet_arm_commands_cb (line 149) | def puppet_arm_commands_cb(self, data):
    method puppet_gripper_commands_cb (line 154) | def puppet_gripper_commands_cb(self, data):
    method print_diagnostics (line 159) | def print_diagnostics(self):
  function get_arm_joint_positions (line 172) | def get_arm_joint_positions(bot):
  function get_arm_gripper_positions (line 176) | def get_arm_gripper_positions(bot):
  function move_arms (line 180) | def move_arms(bot_list, target_pose_list, move_time=1):
  function move_grippers (line 193) | def move_grippers(bot_list, target_pose_list, move_time):
  function setup_puppet_bot (line 214) | def setup_puppet_bot(bot):
  function setup_master_bot (line 221) | def setup_master_bot(bot):
  function set_standard_pid_gains (line 227) | def set_standard_pid_gains(bot):
  function set_low_pid_gains (line 232) | def set_low_pid_gains(bot):
  function torque_off (line 237) | def torque_off(bot):
  function torque_on (line 242) | def torque_on(bot):
  function sync_puppet_to_master (line 248) | def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_...

FILE: examples/aloha_real/video_display.py
  class VideoDisplay (line 7) | class VideoDisplay(_subscriber.Subscriber):
    method __init__ (line 10) | def __init__(self) -> None:
    method on_episode_start (line 15) | def on_episode_start(self) -> None:
    method on_step (line 21) | def on_step(self, observation: dict, action: dict) -> None:
    method on_episode_end (line 34) | def on_episode_end(self) -> None:

FILE: examples/aloha_sim/env.py
  class AlohaSimEnvironment (line 9) | class AlohaSimEnvironment(_environment.Environment):
    method __init__ (line 12) | def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed...
    method reset (line 23) | def reset(self) -> None:
    method is_episode_complete (line 30) | def is_episode_complete(self) -> bool:
    method get_observation (line 34) | def get_observation(self) -> dict:
    method apply_action (line 41) | def apply_action(self, action: dict) -> None:
    method _convert_observation (line 47) | def _convert_observation(self, gym_obs: dict) -> dict:

FILE: examples/aloha_sim/main.py
  class Args (line 15) | class Args:
  function main (line 29) | def main(args: Args) -> None:

FILE: examples/aloha_sim/saver.py
  class VideoSaver (line 10) | class VideoSaver(_subscriber.Subscriber):
    method __init__ (line 13) | def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
    method on_episode_start (line 20) | def on_episode_start(self) -> None:
    method on_step (line 24) | def on_step(self, observation: dict, action: dict) -> None:
    method on_episode_end (line 30) | def on_episode_end(self) -> None:

FILE: examples/convert_jax_model_to_pytorch.py
  function slice_paligemma_state_dict (line 50) | def slice_paligemma_state_dict(state_dict, config):
  function slice_gemma_state_dict (line 271) | def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint...
  function slice_initial_orbax_checkpoint (line 396) | def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precisio...
  function load_jax_model_and_print_keys (line 408) | def load_jax_model_and_print_keys(checkpoint_dir: str):
  function convert_pi0_checkpoint (line 422) | def convert_pi0_checkpoint(
  function main (line 558) | def main(

FILE: examples/droid/convert_droid_data_to_lerobot.py
  function resize_image (line 32) | def resize_image(image, size):
  function main (line 37) | def main(data_dir: str, *, push_to_hub: bool = False):
  function get_camera_type (line 180) | def get_camera_type(cam_id):
  class MP4Reader (line 187) | class MP4Reader:
    method __init__ (line 188) | def __init__(self, filepath, serial_number):
    method set_reading_parameters (line 198) | def set_reading_parameters(
    method get_frame_resolution (line 214) | def get_frame_resolution(self):
    method get_frame_count (line 219) | def get_frame_count(self):
    method set_frame_index (line 224) | def set_frame_index(self, index):
    method _process_frame (line 235) | def _process_frame(self, frame):
    method read_camera (line 241) | def read_camera(self, ignore_data=False, correct_timestamp=None):  # n...
    method disable_camera (line 269) | def disable_camera(self):
  class RecordedMultiCameraWrapper (line 274) | class RecordedMultiCameraWrapper:
    method __init__ (line 275) | def __init__(self, recording_folderpath, camera_kwargs={}):  # noqa: B006
    method read_cameras (line 296) | def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict...
  function get_hdf5_length (line 329) | def get_hdf5_length(hdf5_file, keys_to_ignore=[]):  # noqa: B006
  function load_hdf5_to_dict (line 351) | def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]):  # noqa: B006
  class TrajectoryReader (line 369) | class TrajectoryReader:
    method __init__ (line 370) | def __init__(self, filepath, read_images=True):  # noqa: FBT002
    method length (line 378) | def length(self):
    method read_timestep (line 381) | def read_timestep(self, index=None, keys_to_ignore=[]):  # noqa: B006
    method close (line 400) | def close(self):
  function load_trajectory (line 404) | def load_trajectory(

FILE: examples/droid/main.py
  class Args (line 27) | class Args:
  function prevent_keyboard_interrupt (line 55) | def prevent_keyboard_interrupt():
  function main (line 73) | def main(args: Args):
  function _extract_observation (line 198) | def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):

FILE: examples/libero/convert_libero_data_to_lerobot.py
  function main (line 37) | def main(data_dir: str, *, push_to_hub: bool = False):

FILE: examples/libero/main.py
  class Args (line 22) | class Args:
  function eval_libero (line 48) | def eval_libero(args: Args) -> None:
  function _get_libero_env (line 189) | def _get_libero_env(task, resolution, seed):
  function _quat2axisangle (line 199) | def _quat2axisangle(quat):

FILE: examples/simple_client/main.py
  class EnvMode (line 17) | class EnvMode(enum.Enum):
  class Args (line 27) | class Args:
  class TimingRecorder (line 44) | class TimingRecorder:
    method __init__ (line 47) | def __init__(self) -> None:
    method record (line 50) | def record(self, key: str, time_ms: float) -> None:
    method get_stats (line 56) | def get_stats(self, key: str) -> dict[str, float]:
    method print_all_stats (line 70) | def print_all_stats(self) -> None:
    method write_parquet (line 109) | def write_parquet(self, path: pathlib.Path) -> None:
  function main (line 117) | def main(args: Args) -> None:
  function _random_observation_aloha (line 153) | def _random_observation_aloha() -> dict:
  function _random_observation_droid (line 166) | def _random_observation_droid() -> dict:
  function _random_observation_libero (line 176) | def _random_observation_libero() -> dict:

FILE: packages/openpi-client/src/openpi_client/action_chunk_broker.py
  class ActionChunkBroker (line 10) | class ActionChunkBroker(_base_policy.BasePolicy):
    method __init__ (line 19) | def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
    method infer (line 27) | def infer(self, obs: Dict) -> Dict:  # noqa: UP006
    method reset (line 47) | def reset(self) -> None:

FILE: packages/openpi-client/src/openpi_client/base_policy.py
  class BasePolicy (line 5) | class BasePolicy(abc.ABC):
    method infer (line 7) | def infer(self, obs: Dict) -> Dict:
    method reset (line 10) | def reset(self) -> None:

FILE: packages/openpi-client/src/openpi_client/image_tools.py
  function convert_to_uint8 (line 5) | def convert_to_uint8(img: np.ndarray) -> np.ndarray:
  function resize_with_pad (line 15) | def resize_with_pad(images: np.ndarray, height: int, width: int, method=...
  function _resize_with_pad_pil (line 38) | def _resize_with_pad_pil(image: Image.Image, height: int, width: int, me...

FILE: packages/openpi-client/src/openpi_client/image_tools_test.py
  function test_resize_with_pad_shapes (line 6) | def test_resize_with_pad_shapes():

FILE: packages/openpi-client/src/openpi_client/msgpack_numpy.py
  function pack_array (line 21) | def pack_array(obj):
  function unpack_array (line 43) | def unpack_array(obj):

FILE: packages/openpi-client/src/openpi_client/msgpack_numpy_test.py
  function _check (line 8) | def _check(expected, actual):
  function test_pack_unpack (line 42) | def test_pack_unpack(data):

FILE: packages/openpi-client/src/openpi_client/runtime/agent.py
  class Agent (line 4) | class Agent(abc.ABC):
    method get_action (line 12) | def get_action(self, observation: dict) -> dict:
    method reset (line 16) | def reset(self) -> None:

FILE: packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py
  class PolicyAgent (line 7) | class PolicyAgent(_agent.Agent):
    method __init__ (line 10) | def __init__(self, policy: _base_policy.BasePolicy) -> None:
    method get_action (line 14) | def get_action(self, observation: dict) -> dict:
    method reset (line 17) | def reset(self) -> None:

FILE: packages/openpi-client/src/openpi_client/runtime/environment.py
  class Environment (line 4) | class Environment(abc.ABC):
    method reset (line 12) | def reset(self) -> None:
    method is_episode_complete (line 19) | def is_episode_complete(self) -> bool:
    method get_observation (line 27) | def get_observation(self) -> dict:
    method apply_action (line 31) | def apply_action(self, action: dict) -> None:

FILE: packages/openpi-client/src/openpi_client/runtime/runtime.py
  class Runtime (line 10) | class Runtime:
    method __init__ (line 13) | def __init__(
    method run (line 32) | def run(self) -> None:
    method run_in_new_thread (line 40) | def run_in_new_thread(self) -> threading.Thread:
    method mark_episode_complete (line 46) | def mark_episode_complete(self) -> None:
    method _run_episode (line 50) | def _run_episode(self) -> None:
    method _step (line 80) | def _step(self) -> None:

FILE: packages/openpi-client/src/openpi_client/runtime/subscriber.py
  class Subscriber (line 4) | class Subscriber(abc.ABC):
    method on_episode_start (line 11) | def on_episode_start(self) -> None:
    method on_step (line 15) | def on_step(self, observation: dict, action: dict) -> None:
    method on_episode_end (line 19) | def on_episode_end(self) -> None:

FILE: packages/openpi-client/src/openpi_client/websocket_client_policy.py
  class WebsocketClientPolicy (line 12) | class WebsocketClientPolicy(_base_policy.BasePolicy):
    method __init__ (line 18) | def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, ...
    method get_server_metadata (line 29) | def get_server_metadata(self) -> Dict:
    method _wait_for_server (line 32) | def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConne...
    method infer (line 47) | def infer(self, obs: Dict) -> Dict:  # noqa: UP006
    method reset (line 57) | def reset(self) -> None:

FILE: scripts/compute_norm_stats.py
  class RemoveStrings (line 19) | class RemoveStrings(transforms.DataTransformFn):
    method __call__ (line 20) | def __call__(self, x: dict) -> dict:
  function create_torch_dataloader (line 24) | def create_torch_dataloader(
  function create_rlds_dataloader (line 60) | def create_rlds_dataloader(
  function main (line 89) | def main(config_name: str, max_frames: int | None = None):

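`RemoveStrings` is listed as a `transforms.DataTransformFn` in the norm-stats script. Presumably it drops non-numeric fields, such as prompts, so that only arrays reach the statistics. Below is a minimal standalone sketch of that idea. It assumes string-ness is detected from the NumPy dtype, and the function name `remove_strings` is hypothetical.

```python
import numpy as np


def remove_strings(x: dict) -> dict:
    """Drop entries whose values are unicode strings or arrays of them."""
    return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}
```

For example, `remove_strings({"state": np.zeros(3), "prompt": "pick up the cup"})` keeps only the `"state"` entry.
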
FILE: scripts/serve_policy.py
  class EnvMode (line 14) | class EnvMode(enum.Enum):
  class Checkpoint (line 24) | class Checkpoint:
  class Default (line 34) | class Default:
  class Args (line 39) | class Args:
  function create_default_policy (line 79) | def create_default_policy(env: EnvMode, *, default_prompt: str | None = ...
  function create_policy (line 88) | def create_policy(args: Args) -> _policy.Policy:
  function main (line 99) | def main(args: Args) -> None:

FILE: scripts/train.py
  function init_logging (line 31) | def init_logging():
  function init_wandb (line 50) | def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code:...
  function _load_weights_and_validate (line 73) | def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, par...
  function init_train_state (line 85) | def init_train_state(
  function train_step (line 137) | def train_step(
  function main (line 194) | def main(config: _config.TrainConfig):

FILE: scripts/train_pytorch.py
  function init_logging (line 50) | def init_logging():
  function init_wandb (line 72) | def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: ...
  function setup_ddp (line 94) | def setup_ddp():
  function cleanup_ddp (line 112) | def cleanup_ddp():
  function set_seed (line 118) | def set_seed(seed: int, local_rank: int):
  function build_datasets (line 125) | def build_datasets(config: _config.TrainConfig):
  function get_model_state_dict (line 131) | def get_model_state_dict(model):
  function get_model_parameters (line 140) | def get_model_parameters(model):
  function save_checkpoint (line 149) | def save_checkpoint(model, optimizer, global_step, config, is_main, data...
  function load_checkpoint (line 197) | def load_checkpoint(model, optimizer, checkpoint_dir, device):
  function get_latest_checkpoint_step (line 274) | def get_latest_checkpoint_step(checkpoint_dir):
  function log_memory_usage (line 284) | def log_memory_usage(device, step, phase="unknown"):
  function train_loop (line 309) | def train_loop(config: _config.TrainConfig):
  function main (line 625) | def main():

FILE: scripts/train_test.py
  function test_train (line 15) | def test_train(tmp_path: pathlib.Path, config_name: str):

FILE: src/openpi/conftest.py
  function set_jax_cpu_backend_if_no_gpu (line 7) | def set_jax_cpu_backend_if_no_gpu() -> None:
  function pytest_configure (line 16) | def pytest_configure(config: pytest.Config) -> None:

FILE: src/openpi/models/gemma.py
  class Config (line 45) | class Config:
  function get_config (line 58) | def get_config(variant: Variant) -> Config:
  class RMSNorm (line 113) | class RMSNorm(nn.Module):
    method __call__ (line 115) | def __call__(self, x, cond):
  class Embedder (line 135) | class Embedder(nn.Module):
    method setup (line 141) | def setup(self):
    method encode (line 148) | def encode(self, x):
    method decode (line 153) | def decode(self, x):
  class Attention (line 158) | class Attention(nn.Module):
    method __call__ (line 164) | def __call__(self, xs, positions, attn_mask, kv_cache):
  class FeedForward (line 253) | class FeedForward(nn.Module):
    method __call__ (line 260) | def __call__(self, x):
  class Block (line 284) | class Block(nn.Module):
    method __call__ (line 293) | def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, de...
  class Module (line 340) | class Module(nn.Module):
    method setup (line 350) | def setup(self):
    method embed (line 385) | def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array,...
    method __call__ (line 389) | def __call__(
    method init (line 413) | def init(self, use_adarms: Sequence[bool]):
  function _apply_rope (line 424) | def _apply_rope(x, *, positions, max_wavelength=10_000):
  function _name (line 443) | def _name(name, i):
  function _gated_residual (line 453) | def _gated_residual(x, y, gate):

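`_apply_rope(x, *, positions, max_wavelength=10_000)` applies rotary position embeddings (RoPE) from per-token positions. Below is a NumPy sketch of the common half-split formulation. It is not copied from `gemma.py`: the `[batch, seq, heads, head_dim]` shape convention and the name `apply_rope` are assumptions. Each pair (x1[i], x2[i]) is rotated by the angle position / max_wavelength^(2i/d).

```python
import numpy as np


def apply_rope(x: np.ndarray, *, positions: np.ndarray, max_wavelength: float = 10_000) -> np.ndarray:
    """Rotary embedding sketch.

    x: [batch, seq, heads, head_dim] with an even head_dim; positions: [batch, seq].
    """
    d = x.shape[-1]
    freq_exponents = (2.0 / d) * np.arange(d // 2)
    timescale = max_wavelength**freq_exponents  # [d/2], from 1 up to ~max_wavelength
    radians = positions[..., None] / timescale  # [batch, seq, d/2]
    radians = radians[..., None, :]  # broadcast over heads: [batch, seq, 1, d/2]
    sin, cos = np.sin(radians), np.cos(radians)
    x1, x2 = np.split(x, 2, axis=-1)
    # Standard 2-D rotation applied to every (x1[i], x2[i]) pair.
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
```

Because each pair is rotated rather than shifted, the dot product between a rotated query at position m and a rotated key at position n depends only on n - m. This is the property that makes RoPE a relative position encoding.
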
FILE: src/openpi/models/gemma_fast.py
  function get_config (line 35) | def get_config(variant):
  class Einsum (line 77) | class Einsum(nn.Module):
    method __call__ (line 81) | def __call__(self, eqn, x):
  class RMSNorm (line 88) | class RMSNorm(nn.Module):
    method __call__ (line 90) | def __call__(self, x):
  class Embedder (line 102) | class Embedder(nn.Module):
    method setup (line 108) | def setup(self):
    method encode (line 115) | def encode(self, x):
    method decode (line 120) | def decode(self, x):
  class Attention (line 125) | class Attention(nn.Module):
    method setup (line 137) | def setup(self):
    method _init_cache (line 165) | def _init_cache(self, k, v, cache_size):
    method _update_cache (line 175) | def _update_cache(self, k, v, idx, k_cache, v_cache):
    method __call__ (line 186) | def __call__(self, x, positions, attn_mask, kv_cache, decode, determin...
  class Block (line 228) | class Block(nn.Module):
    method setup (line 242) | def setup(self):
    method __call__ (line 261) | def __call__(self, x, kv_cache, positions, attn_mask, decode, determin...
  class Module (line 279) | class Module(nn.Module):
    method __call__ (line 303) | def __call__(
    method init (line 420) | def init(self):
  function _apply_rope (line 425) | def _apply_rope(x, *, positions, max_wavelength=10_000):

FILE: src/openpi/models/lora.py
  class LoRAConfig (line 12) | class LoRAConfig:
    method scaling_value (line 29) | def scaling_value(self) -> float:
  class Einsum (line 33) | class Einsum(nn.Module):
    method setup (line 43) | def setup(self):
    method __call__ (line 55) | def __call__(self, eqn: str, x):
    method _make_lora_eqns (line 67) | def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
  class FeedForward (line 88) | class FeedForward(nn.Module):
    method setup (line 96) | def setup(self):
    method __call__ (line 124) | def __call__(self, x):
    method _dot (line 144) | def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array,...

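`lora.py` wraps `Einsum` and `FeedForward` with low-rank adapters configured by `LoRAConfig`, whose `scaling_value` property sets the adapter's scale. The standard LoRA convention and its rank-stabilized variant (rsLoRA) are:

- LoRA: the adapted projection is y = xW + s·(xA)B with s = alpha / rank;
- rsLoRA: the same update, with s = alpha / sqrt(rank).

The NumPy sketch below illustrates that convention for a plain matrix multiply. The fields `rank`, `alpha`, `rslora` and the names `LoRAConfigSketch` and `lora_matmul` are assumptions, not confirmed openpi definitions.

```python
import math
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class LoRAConfigSketch:
    rank: int
    alpha: float = 1.0
    rslora: bool = False  # if True, use the rank-stabilized scale alpha / sqrt(rank)

    @property
    def scaling_value(self) -> float:
        return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank


def lora_matmul(x, w, a, b, config: LoRAConfigSketch):
    """Frozen projection x @ w plus the scaled low-rank update (x @ a) @ b.

    Shapes: x [n, d_in], w [d_in, d_out], a [d_in, rank], b [rank, d_out].
    """
    return x @ w + config.scaling_value * ((x @ a) @ b)
```

In common LoRA practice `b` starts at zero, so the adapted layer initially matches the frozen base layer exactly. How openpi initializes its adapter weights is not shown in this index.
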
FILE: src/openpi/models/lora_test.py
  function test_lora_einsum_params_shape (line 8) | def test_lora_einsum_params_shape():
  function test_lora_einsum_same_output (line 34) | def test_lora_einsum_same_output():
  function test_lora_ffn_params_shape (line 53) | def test_lora_ffn_params_shape():
  function test_lora_ffn_same_output (line 77) | def test_lora_ffn_same_output():

FILE: src/openpi/models/model.py
  class ModelType (line 30) | class ModelType(enum.Enum):
  class Observation (line 83) | class Observation(Generic[ArrayT]):
    method from_dict (line 110) | def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]":
    method to_dict (line 131) | def to_dict(self) -> at.PyTree[ArrayT]:
  function preprocess_observation (line 144) | def preprocess_observation(
  class BaseModelConfig (line 212) | class BaseModelConfig(abc.ABC):
    method model_type (line 226) | def model_type(self) -> ModelType:
    method create (line 230) | def create(self, rng: at.KeyArrayLike) -> "BaseModel":
    method load (line 233) | def load(self, params: at.Params, *, remove_extra_params: bool = True)...
    method load_pytorch (line 243) | def load_pytorch(self, train_config, weight_path: str):
    method inputs_spec (line 250) | def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Ac...
    method fake_obs (line 253) | def fake_obs(self, batch_size: int = 1) -> Observation:
    method fake_act (line 257) | def fake_act(self, batch_size: int = 1) -> Actions:
  class BaseModel (line 263) | class BaseModel(nnx.Module, abc.ABC):
    method compute_loss (line 273) | def compute_loss(
    method sample_actions (line 283) | def sample_actions(self, rng: at.KeyArrayLike, observation: Observatio...
  function restore_params (line 286) | def restore_params(

FILE: src/openpi/models/model_test.py
  function test_pi0_model (line 12) | def test_pi0_model():
  function test_pi0_lora_model (line 27) | def test_pi0_lora_model():
  function test_pi0_fast_model (line 42) | def test_pi0_fast_model():
  function test_pi0_fast_lora_model (line 57) | def test_pi0_fast_lora_model():
  function test_model_restore (line 79) | def test_model_restore():

FILE: src/openpi/models/pi0.py
  function make_attn_mask (line 19) | def make_attn_mask(input_mask, mask_ar):
  function posemb_sincos (line 48) | def posemb_sincos(
  class Pi0 (line 66) | class Pi0(_model.BaseModel):
    method __init__ (line 67) | def __init__(self, config: pi0_config.Pi0Config, rngs: nnx.Rngs):
    method embed_prefix (line 106) | def embed_prefix(
    method embed_suffix (line 140) | def embed_suffix(
    method compute_loss (line 189) | def compute_loss(
    method sample_actions (line 217) | def sample_actions(

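`make_attn_mask(input_mask, mask_ar)` builds a prefix-LM style attention mask: a bidirectional prefix followed by block-causal tokens. Below is a NumPy sketch of the cumulative-sum construction common in big_vision-style code. It assumes two things:

- `input_mask` marks real (non-padding) tokens;
- `mask_ar[i]` is True where token i starts a new causal block.

A token may then attend to every valid token whose block index is less than or equal to its own. The exact semantics in `pi0.py` are not shown here.

```python
import numpy as np


def make_attn_mask(input_mask: np.ndarray, mask_ar: np.ndarray) -> np.ndarray:
    """Return a [batch, seq, seq] bool mask; entry [b, i, j] means token i may attend to token j.

    input_mask: [batch, seq] bool, True for real (non-padding) tokens.
    mask_ar:    [batch, seq] bool, True where a new causal block starts.
    """
    block_index = np.cumsum(mask_ar.astype(np.int32), axis=1)
    # Token i sees token j if j's block does not come after i's block.
    attn_mask = block_index[:, None, :] <= block_index[:, :, None]
    # Padding tokens neither attend nor are attended to.
    valid_mask = input_mask[:, None, :] & input_mask[:, :, None]
    return attn_mask & valid_mask
```

Consider `input_mask=[[1, 1, 1, 1, 0]]` and `mask_ar=[[0, 0, 1, 1, 0]]` (as booleans):

- tokens 0 and 1 form a bidirectional prefix;
- tokens 2 and 3 are causal;
- the padded token 4 is fully masked.
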
FILE: src/openpi/models/pi0_config.py
  class Pi0Config (line 19) | class Pi0Config(_model.BaseModelConfig):
    method __post_init__ (line 37) | def __post_init__(self):
    method model_type (line 52) | def model_type(self) -> _model.ModelType:
    method create (line 58) | def create(self, rng: at.KeyArrayLike) -> "Pi0":
    method inputs_spec (line 64) | def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observat...
    method get_freeze_filter (line 88) | def get_freeze_filter(self) -> nnx.filterlib.Filter:

FILE: src/openpi/models/pi0_fast.py
  function make_attn_mask (line 23) | def make_attn_mask(input_mask, mask_ar):
  function left_to_right_align (line 52) | def left_to_right_align(x, input_mask, attn_mask):
  function put_along_last_axis (line 67) | def put_along_last_axis(arr, indices, values):
  class Pi0FASTConfig (line 77) | class Pi0FASTConfig(_model.BaseModelConfig):
    method model_type (line 93) | def model_type(self) -> _model.ModelType:
    method create (line 97) | def create(self, rng: at.KeyArrayLike) -> "Pi0FAST":
    method inputs_spec (line 101) | def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observat...
    method get_freeze_filter (line 127) | def get_freeze_filter(self) -> nnx.filterlib.Filter:
  class Pi0FAST (line 134) | class Pi0FAST(_model.BaseModel):
    method __init__ (line 135) | def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):
    method embed_inputs (line 160) | def embed_inputs(
    method compute_loss (line 198) | def compute_loss(
    method sample_actions (line 236) | def sample_actions(

FILE: src/openpi/models/pi0_test.py
  function _get_frozen_state (line 7) | def _get_frozen_state(config: _pi0_config.Pi0Config) -> nnx.State:
  function test_pi0_full_finetune (line 14) | def test_pi0_full_finetune():
  function test_pi0_gemma_lora (line 20) | def test_pi0_gemma_lora():
  function test_pi0_action_expert_lora (line 29) | def test_pi0_action_expert_lora():
  function test_pi0_all_lora (line 40) | def test_pi0_all_lora():

FILE: src/openpi/models/siglip.py
  function posemb_sincos_2d (line 27) | def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32):
  function get_posemb (line 40) | def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32):
  class MlpBlock (line 53) | class MlpBlock(nn.Module):
    method __call__ (line 61) | def __call__(self, x, deterministic=True):  # noqa: FBT002
  class Encoder1DBlock (line 75) | class Encoder1DBlock(nn.Module):
    method __call__ (line 84) | def __call__(self, x, deterministic=True):  # noqa: FBT002
  class Encoder (line 111) | class Encoder(nn.Module):
    method __call__ (line 123) | def __call__(self, x, deterministic=True):  # noqa: FBT002
  class MAPHead (line 164) | class MAPHead(nn.Module):
    method __call__ (line 172) | def __call__(self, x):
  class _Module (line 188) | class _Module(nn.Module):
    method __call__ (line 208) | def __call__(self, image, *, train=False):
  function Module (line 293) | def Module(num_classes=None, *, variant=None, **kw):  # pylint: disable=...
  function decode_variant (line 298) | def decode_variant(variant):

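`posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32)` produces fixed 2-D sine/cosine position embeddings for an h by w patch grid. Below is a NumPy sketch of the usual big_vision-style construction. It assumes:

- `width` splits into four equal quarters: sin/cos of the x coordinate, then sin/cos of the y coordinate;
- the result has a leading batch axis of 1.

The quarter ordering and the NumPy port are assumptions.

```python
import numpy as np


def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=np.float32):
    """Return fixed sin/cos embeddings of shape [1, h * w, width] for an h x w grid."""
    assert width % 4 == 0 and width >= 8, "width must be a multiple of 4 and at least 8"
    y, x = np.mgrid[:h, :w]
    omega = np.arange(width // 4) / (width // 4 - 1)
    omega = 1.0 / (temperature**omega)  # frequencies from 1 down to 1 / temperature
    y = np.einsum("m,d->md", y.ravel().astype(np.float64), omega)
    x = np.einsum("m,d->md", x.ravel().astype(np.float64), omega)
    pe = np.concatenate([np.sin(x), np.cos(x), np.sin(y), np.cos(y)], axis=1)
    return np.asarray(pe, dtype)[None, :, :]
```

The grid cell at (0, 0) always maps to zeros in the sine quarters and ones in the cosine quarters. Nearby cells get similar embeddings because the lowest frequencies change slowly across the grid.
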
FILE: src/openpi/models/tokenizer.py
  class PaligemmaTokenizer (line 14) | class PaligemmaTokenizer:
    method __init__ (line 15) | def __init__(self, max_len: int = 48):
    method tokenize (line 22) | def tokenize(self, prompt: str, state: np.ndarray | None = None) -> tu...
  class FASTTokenizer (line 51) | class FASTTokenizer:
    method __init__ (line 52) | def __init__(self, max_len: int = 256, fast_tokenizer_path: str = "phy...
    method tokenize (line 64) | def tokenize(
    method extract_actions (line 119) | def extract_actions(self, tokens: np.ndarray, action_horizon: int, act...
    method _act_tokens_to_paligemma_tokens (line 136) | def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[in...
  class BinningTokenizer (line 148) | class BinningTokenizer:
    method __init__ (line 153) | def __init__(self, max_len: int = 256, n_bins: int = 256):
    method tokenize (line 164) | def tokenize(
    method extract_actions (line 222) | def extract_actions(self, tokens: np.ndarray, action_horizon: int, act...
    method _act_tokens_to_paligemma_tokens (line 240) | def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[in...
  class FSQTokenizer (line 246) | class FSQTokenizer:
    method __init__ (line 251) | def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None ...
    method tokenize (line 300) | def tokenize(
    method extract_actions (line 345) | def extract_actions(self, tokens: np.ndarray, action_horizon: int, act...
    method _act_tokens_to_paligemma_tokens (line 368) | def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[in...

FILE: src/openpi/models/tokenizer_test.py
  function test_tokenize (line 6) | def test_tokenize():
  function test_fast_tokenizer (line 14) | def test_fast_tokenizer():

FILE: src/openpi/models/utils/fsq_tokenizer.py
  class FsqCodebook (line 15) | class FsqCodebook(nn.Module):
    method bins_per_dim (line 23) | def bins_per_dim(self) -> tuple[int]:
    method place_values (line 37) | def place_values(self) -> jnp.ndarray:
    method _get_bins_fsq (line 44) | def _get_bins_fsq(target_codebook_size: int) -> tuple[int]:
    method _get_bins_custom (line 62) | def _get_bins_custom(target_codebook_size: int) -> tuple[int]:
    method _get_bins_lfq (line 76) | def _get_bins_lfq(target_codebook_size: int) -> tuple[int]:
    method setup (line 84) | def setup(self):
    method __call__ (line 88) | def __call__(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndar...
    method encode (line 93) | def encode(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    method decode (line 105) | def decode(self, tokens: jnp.ndarray, z_grad: jax.Array | None = None)...
    method undigitize (line 117) | def undigitize(self, digits: jnp.ndarray) -> jnp.ndarray:
    method digitize (line 120) | def digitize(self, tokens: jnp.ndarray) -> jnp.ndarray:
    method vocab_size (line 124) | def vocab_size(self) -> int:
  class ResNetDownBlock (line 128) | class ResNetDownBlock(nn.Module):
    method __call__ (line 135) | def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
  class ResNetUpBlock (line 150) | class ResNetUpBlock(nn.Module):
    method __call__ (line 157) | def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
  class LfqCodebookOutput (line 173) | class LfqCodebookOutput:
  class LookupFreeQuantization (line 181) | class LookupFreeQuantization(nn.Module):
    method setup (line 185) | def setup(self):
    method encode (line 192) | def encode(self, z: jnp.ndarray) -> jnp.ndarray:
    method decode (line 198) | def decode(self, tokens: jnp.ndarray) -> jnp.ndarray:
    method loss (line 202) | def loss(self, x: jnp.ndarray) -> LfqCodebookOutput:
  function make_block_causal_attention_matrix (line 238) | def make_block_causal_attention_matrix(q: jnp.ndarray, k: jnp.ndarray, b...
  class GeGLU (line 242) | class GeGLU(Module):
    method __call__ (line 255) | def __call__(self, inputs: Array) -> Array:
  class CrossAttentionLayer (line 269) | class CrossAttentionLayer(nn.Module):
    method __call__ (line 276) | def __call__(
  function sinusoidal_pe_init (line 327) | def sinusoidal_pe_init(_, shape: tuple[int, int]) -> jnp.ndarray:
  class TokenizerEncoderDecoder (line 341) | class TokenizerEncoderDecoder(nn.Module):
    method __call__ (line 351) | def __call__(
  class FsqAttentionTokenizer (line 385) | class FsqAttentionTokenizer(nn.Module):
    method vocab_size (line 400) | def vocab_size(self) -> int:
    method setup (line 403) | def setup(self):
    method tokenize (line 430) | def tokenize(
    method detokenize (line 441) | def detokenize(self, tokens: jnp.ndarray, *, obs: jnp.ndarray | None =...
    method loss (line 446) | def loss(
    method __call__ (line 468) | def __call__(self, *args: Any, **kwargs: Any) -> tuple[jnp.ndarray, di...

FILE: src/openpi/models/vit.py
  class IdentityLayer (line 31) | class IdentityLayer(nn.Module):
    method __call__ (line 35) | def __call__(self, x):
  class AddPositionEmbs (line 39) | class AddPositionEmbs(nn.Module):
    method __call__ (line 50) | def __call__(self, inputs):
  class MlpBlock (line 66) | class MlpBlock(nn.Module):
    method __call__ (line 78) | def __call__(self, inputs, *, deterministic):
  class Encoder1DBlock (line 104) | class Encoder1DBlock(nn.Module):
    method __call__ (line 124) | def __call__(self, inputs, deterministic):
  class Encoder (line 160) | class Encoder(nn.Module):
    method __call__ (line 180) | def __call__(self, x, *, train):
  class VisionTransformer (line 219) | class VisionTransformer(nn.Module):
    method __call__ (line 235) | def __call__(self, inputs, *, train):

FILE: src/openpi/models_pytorch/gemma_pytorch.py
  class PaliGemmaWithExpertModel (line 12) | class PaliGemmaWithExpertModel(nn.Module):
    method __init__ (line 13) | def __init__(
    method to_bfloat16_for_selected_params (line 63) | def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16...
    method embed_image (line 85) | def embed_image(self, image: torch.Tensor):
    method embed_language_tokens (line 88) | def embed_language_tokens(self, tokens: torch.Tensor):
    method forward (line 91) | def forward(

FILE: src/openpi/models_pytorch/pi0_pytorch.py
  function get_safe_dtype (line 14) | def get_safe_dtype(target_dtype, device_type):
  function create_sinusoidal_pos_embedding (line 25) | def create_sinusoidal_pos_embedding(
  function sample_beta (line 45) | def sample_beta(alpha, beta, bsize, device):
  function make_att_2d_masks (line 52) | def make_att_2d_masks(pad_masks, att_masks):
  class PI0Pytorch (line 84) | class PI0Pytorch(nn.Module):
    method __init__ (line 85) | def __init__(self, config):
    method gradient_checkpointing_enable (line 127) | def gradient_checkpointing_enable(self):
    method gradient_checkpointing_disable (line 136) | def gradient_checkpointing_disable(self):
    method is_gradient_checkpointing_enabled (line 145) | def is_gradient_checkpointing_enabled(self):
    method _apply_checkpoint (line 149) | def _apply_checkpoint(self, func, *args, **kwargs):
    method _prepare_attention_masks_4d (line 157) | def _prepare_attention_masks_4d(self, att_2d_masks):
    method _preprocess_observation (line 162) | def _preprocess_observation(self, observation, *, train=True):
    method sample_noise (line 173) | def sample_noise(self, shape, device):
    method sample_time (line 182) | def sample_time(self, bsize, device):
    method embed_prefix (line 187) | def embed_prefix(
    method embed_suffix (line 238) | def embed_suffix(self, state, noisy_actions, timestep):
    method forward (line 317) | def forward(self, observation, actions, noise=None, time=None) -> Tensor:
    method sample_actions (line 377) | def sample_actions(self, device, observation, noise=None, num_steps=10...
    method denoise_step (line 422) | def denoise_step(

FILE: src/openpi/models_pytorch/preprocessing_pytorch.py
  function preprocess_observation_pytorch (line 20) | def preprocess_observation_pytorch(

FILE: src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py
  class GemmaConfig (line 26) | class GemmaConfig(PretrainedConfig):
    method __init__ (line 115) | def __init__(

FILE: src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py
  class GemmaRMSNorm (line 49) | class GemmaRMSNorm(nn.Module):
    method __init__ (line 50) | def __init__(self, dim: int, eps: float = 1e-6, cond_dim: Optional[int...
    method _norm (line 66) | def _norm(self, x):
    method forward (line 73) | def forward(self, x, cond=None):
    method extra_repr (line 106) | def extra_repr(self):
  class GemmaMLP (line 113) | class GemmaMLP(nn.Module):
    method __init__ (line 114) | def __init__(self, config):
    method forward (line 124) | def forward(self, x):
  class GemmaRotaryEmbedding (line 129) | class GemmaRotaryEmbedding(nn.Module):
    method __init__ (line 130) | def __init__(self, config: GemmaConfig, device=None):
    method forward (line 149) | def forward(self, x, position_ids):
  function rotate_half (line 163) | def rotate_half(x):
  function apply_rotary_pos_emb (line 170) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_di...
  function repeat_kv (line 197) | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  function _gated_residual (line 209) | def _gated_residual(x, y, gate):
  function eager_attention_forward (line 230) | def eager_attention_forward(
  class GemmaAttention (line 256) | class GemmaAttention(nn.Module):
    method __init__ (line 259) | def __init__(self, config: GemmaConfig, layer_idx: int):
    method forward (line 282) | def forward(
  class GemmaDecoderLayer (line 332) | class GemmaDecoderLayer(GradientCheckpointingLayer):
    method __init__ (line 333) | def __init__(self, config: GemmaConfig, layer_idx: int):
    method forward (line 344) | def forward(
  class GemmaPreTrainedModel (line 388) | class GemmaPreTrainedModel(PreTrainedModel):
    method _init_weights (line 403) | def _init_weights(self, module):
  class GemmaModel (line 419) | class GemmaModel(GemmaPreTrainedModel):
    method __init__ (line 420) | def __init__(self, config: GemmaConfig):
    method get_input_embeddings (line 438) | def get_input_embeddings(self):
    method set_input_embeddings (line 441) | def set_input_embeddings(self, value):
    method forward (line 446) | def forward(
  class KwargsForCausalLM (line 558) | class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
  class GemmaForCausalLM (line 562) | class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
    method __init__ (line 567) | def __init__(self, config):
    method get_input_embeddings (line 576) | def get_input_embeddings(self):
    method set_input_embeddings (line 579) | def set_input_embeddings(self, value):
    method get_output_embeddings (line 582) | def get_output_embeddings(self):
    method set_output_embeddings (line 585) | def set_output_embeddings(self, new_embeddings):
    method set_decoder (line 588) | def set_decoder(self, decoder):
    method get_decoder (line 591) | def get_decoder(self):
    method forward (line 596) | def forward(
  class GemmaForSequenceClassification (line 689) | class GemmaForSequenceClassification(GemmaPreTrainedModel):
    method __init__ (line 690) | def __init__(self, config):
    method get_input_embeddings (line 699) | def get_input_embeddings(self):
    method set_input_embeddings (line 702) | def set_input_embeddings(self, value):
    method forward (line 707) | def forward(
  class GemmaForTokenClassification (line 781) | class GemmaForTokenClassification(GemmaPreTrainedModel):
    method __init__ (line 782) | def __init__(self, config):
    method get_input_embeddings (line 798) | def get_input_embeddings(self):
    method set_input_embeddings (line 801) | def set_input_embeddings(self, value):
    method forward (line 806) | def forward(

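`repeat_kv(hidden_states, n_rep)` in the vendored `modeling_gemma.py` supports grouped-query and multi-query attention, where several query heads share one key/value head. In upstream Hugging Face Transformers, this helper is equivalent to `torch.repeat_interleave(x, dim=1, repeats=n_rep)` on a `[batch, num_kv_heads, seq, head_dim]` tensor. The vendored copy is presumably unchanged. Below is a NumPy sketch of the same expand-and-reshape under that shape convention.

```python
import numpy as np


def repeat_kv(hidden_states: np.ndarray, n_rep: int) -> np.ndarray:
    """[batch, num_kv_heads, seq, head_dim] -> [batch, num_kv_heads * n_rep, seq, head_dim]."""
    batch, num_kv_heads, seq, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis right after the head axis, broadcast it, then merge it into the
    # head axis so each KV head appears n_rep times consecutively.
    expanded = np.broadcast_to(
        hidden_states[:, :, None, :, :], (batch, num_kv_heads, n_rep, seq, head_dim)
    )
    return expanded.reshape(batch, num_kv_heads * n_rep, seq, head_dim)
```

With 2 KV heads and `n_rep=3`, output heads 0 to 2 are copies of KV head 0 and heads 3 to 5 are copies of KV head 1. This matches how consecutive query heads are grouped.
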
FILE: src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py
  class PaligemmaModelOutputWithPast (line 44) | class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
  class PaliGemmaCausalLMOutputWithPast (line 66) | class PaliGemmaCausalLMOutputWithPast(ModelOutput):
  class PaliGemmaMultiModalProjector (line 91) | class PaliGemmaMultiModalProjector(nn.Module):
    method __init__ (line 92) | def __init__(self, config: PaliGemmaConfig):
    method forward (line 96) | def forward(self, image_features):
  class PaliGemmaPreTrainedModel (line 103) | class PaliGemmaPreTrainedModel(PreTrainedModel):
    method _init_weights (line 117) | def _init_weights(self, module):
  class PaliGemmaModel (line 133) | class PaliGemmaModel(PaliGemmaPreTrainedModel):
    method __init__ (line 138) | def __init__(self, config: PaliGemmaConfig):
    method get_input_embeddings (line 151) | def get_input_embeddings(self):
    method set_input_embeddings (line 155) | def set_input_embeddings(self, value):
    method set_decoder (line 158) | def set_decoder(self, decoder):
    method get_decoder (line 161) | def get_decoder(self):
    method _update_causal_mask (line 164) | def _update_causal_mask(
    method get_image_features (line 232) | def get_image_features(self, pixel_values: torch.FloatTensor):
    method forward (line 249) | def forward(
  class KwargsForCausalLM (line 372) | class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
  class PaliGemmaForConditionalGeneration (line 380) | class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, Genera...
    method __init__ (line 389) | def __init__(self, config: PaliGemmaConfig):
    method get_input_embeddings (line 395) | def get_input_embeddings(self):
    method set_input_embeddings (line 398) | def set_input_embeddings(self, value):
    method get_output_embeddings (line 401) | def get_output_embeddings(self):
    method set_output_embeddings (line 404) | def set_output_embeddings(self, new_embeddings):
    method set_decoder (line 407) | def set_decoder(self, decoder):
    method get_decoder (line 410) | def get_decoder(self):
    method get_image_features (line 413) | def get_image_features(self, pixel_values):
    method language_model (line 418) | def language_model(self):
    method vision_tower (line 422) | def vision_tower(self):
    method multi_modal_projector (line 426) | def multi_modal_projector(self):
    method forward (line 431) | def forward(
    method prepare_inputs_for_generation (line 519) | def prepare_inputs_for_generation(
    method _prepare_4d_causal_attention_mask_with_cache_position (line 567) | def _prepare_4d_causal_attention_mask_with_cache_position(

FILE: src/openpi/models_pytorch/transformers_replace/models/siglip/check.py
  function check_whether_transformers_replace_is_installed_correctly (line 3) | def check_whether_transformers_replace_is_installed_correctly():

FILE: src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py
  function _trunc_normal_ (line 41) | def _trunc_normal_(tensor, mean, std, a, b):
  function trunc_normal_tf_ (line 77) | def trunc_normal_tf_(
  function variance_scaling_ (line 103) | def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="no...
  function lecun_normal_ (line 128) | def lecun_normal_(tensor):
  function default_flax_embed_init (line 132) | def default_flax_embed_init(tensor):
  class SiglipVisionModelOutput (line 143) | class SiglipVisionModelOutput(ModelOutput):
  class SiglipTextModelOutput (line 162) | class SiglipTextModelOutput(ModelOutput):
  class SiglipOutput (line 177) | class SiglipOutput(ModelOutput):
    method to_tuple (line 205) | def to_tuple(self) -> tuple[Any]:
  class SiglipVisionEmbeddings (line 212) | class SiglipVisionEmbeddings(nn.Module):
    method __init__ (line 213) | def __init__(self, config: SiglipVisionConfig):
    method interpolate_pos_encoding (line 233) | def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: i...
    method forward (line 271) | def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_enc...
  class SiglipTextEmbeddings (line 285) | class SiglipTextEmbeddings(nn.Module):
    method __init__ (line 286) | def __init__(self, config: SiglipTextConfig):
    method forward (line 298) | def forward(
  function eager_attention_forward (line 325) | def eager_attention_forward(
  class SiglipAttention (line 348) | class SiglipAttention(nn.Module):
    method __init__ (line 351) | def __init__(self, config):
    method forward (line 371) | def forward(
  class SiglipMLP (line 420) | class SiglipMLP(nn.Module):
    method __init__ (line 421) | def __init__(self, config):
    method forward (line 428) | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  class SiglipEncoderLayer (line 435) | class SiglipEncoderLayer(GradientCheckpointingLayer):
    method __init__ (line 436) | def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]):
    method forward (line 444) | def forward(
  class SiglipPreTrainedModel (line 484) | class SiglipPreTrainedModel(PreTrainedModel):
    method _init_weights (line 501) | def _init_weights(self, module):
  class SiglipEncoder (line 549) | class SiglipEncoder(nn.Module):
    method __init__ (line 558) | def __init__(self, config: SiglipConfig):
    method forward (line 566) | def forward(
  class SiglipTextTransformer (line 629) | class SiglipTextTransformer(nn.Module):
    method __init__ (line 630) | def __init__(self, config: SiglipTextConfig):
    method forward (line 643) | def forward(
  class SiglipTextModel (line 697) | class SiglipTextModel(SiglipPreTrainedModel):
    method __init__ (line 700) | def __init__(self, config: SiglipTextConfig):
    method get_input_embeddings (line 706) | def get_input_embeddings(self) -> nn.Module:
    method set_input_embeddings (line 709) | def set_input_embeddings(self, value):
    method forward (line 714) | def forward(
  class SiglipVisionTransformer (line 748) | class SiglipVisionTransformer(nn.Module):
    method __init__ (line 749) | def __init__(self, config: SiglipVisionConfig):
    method forward (line 763) | def forward(
  class SiglipMultiheadAttentionPoolingHead (line 799) | class SiglipMultiheadAttentionPoolingHead(nn.Module):
    method __init__ (line 802) | def __init__(self, config: SiglipVisionConfig):
    method forward (line 810) | def forward(self, hidden_state):
  class SiglipVisionModel (line 828) | class SiglipVisionModel(SiglipPreTrainedModel):
    method __init__ (line 832) | def __init__(self, config: SiglipVisionConfig):
    method get_input_embeddings (line 840) | def get_input_embeddings(self) -> nn.Module:
    method forward (line 845) | def forward(
  class SiglipModel (line 882) | class SiglipModel(SiglipPreTrainedModel):
    method __init__ (line 885) | def __init__(self, config: SiglipConfig):
    method get_text_features (line 918) | def get_text_features(
    method get_image_features (line 964) | def get_image_features(
    method forward (line 1014) | def forward(
  class SiglipForImageClassification (line 1117) | class SiglipForImageClassification(SiglipPreTrainedModel):
    method __init__ (line 1120) | def __init__(self, config: SiglipConfig) -> None:
    method forward (line 1140) | def forward(

FILE: src/openpi/policies/aloha_policy.py
  function make_aloha_example (line 10) | def make_aloha_example() -> dict:
  class AlohaInputs (line 25) | class AlohaInputs(transforms.DataTransformFn):
    method __call__ (line 42) | def __call__(self, data: dict) -> dict:
  class AlohaOutputs (line 91) | class AlohaOutputs(transforms.DataTransformFn):
    method __call__ (line 98) | def __call__(self, data: dict) -> dict:
  function _joint_flip_mask (line 104) | def _joint_flip_mask() -> np.ndarray:
  function _normalize (line 109) | def _normalize(x, min_val, max_val):
  function _unnormalize (line 113) | def _unnormalize(x, min_val, max_val):
  function _gripper_to_angular (line 117) | def _gripper_to_angular(value):
  function _gripper_from_angular (line 140) | def _gripper_from_angular(value):
  function _gripper_from_angular_inv (line 153) | def _gripper_from_angular_inv(value):
  function _decode_aloha (line 159) | def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
  function _decode_state (line 181) | def _decode_state(state: np.ndarray, *, adapt_to_pi: bool = False) -> np...
  function _encode_actions (line 190) | def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -...
  function _encode_actions_inv (line 198) | def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = Fals...

FILE: src/openpi/policies/droid_policy.py
  function make_droid_example (line 10) | def make_droid_example() -> dict:
  function _parse_image (line 21) | def _parse_image(image) -> np.ndarray:
  class DroidInputs (line 31) | class DroidInputs(transforms.DataTransformFn):
    method __call__ (line 35) | def __call__(self, data: dict) -> dict:
  class DroidOutputs (line 78) | class DroidOutputs(transforms.DataTransformFn):
    method __call__ (line 79) | def __call__(self, data: dict) -> dict:

FILE: src/openpi/policies/libero_policy.py
  function make_libero_example (line 10) | def make_libero_example() -> dict:
  function _parse_image (line 20) | def _parse_image(image) -> np.ndarray:
  class LiberoInputs (line 30) | class LiberoInputs(transforms.DataTransformFn):
    method __call__ (line 42) | def __call__(self, data: dict) -> dict:
  class LiberoOutputs (line 87) | class LiberoOutputs(transforms.DataTransformFn):
    method __call__ (line 95) | def __call__(self, data: dict) -> dict:

FILE: src/openpi/policies/policy.py
  class Policy (line 24) | class Policy(BasePolicy):
    method __init__ (line 25) | def __init__(
    method infer (line 68) | def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict...
    method metadata (line 109) | def metadata(self) -> dict[str, Any]:
  class PolicyRecorder (line 113) | class PolicyRecorder(_base_policy.BasePolicy):
    method __init__ (line 116) | def __init__(self, policy: _base_policy.BasePolicy, record_dir: str):
    method infer (line 125) | def infer(self, obs: dict) -> dict:  # type: ignore[misc]

FILE: src/openpi/policies/policy_config.py
  function create_trained_policy (line 16) | def create_trained_policy(

FILE: src/openpi/policies/policy_test.py
  function test_infer (line 10) | def test_infer():
  function test_broker (line 21) | def test_broker():

FILE: src/openpi/serving/websocket_policy_server.py
  class WebsocketPolicyServer (line 15) | class WebsocketPolicyServer:
    method __init__ (line 21) | def __init__(
    method serve_forever (line 34) | def serve_forever(self) -> None:
    method run (line 37) | async def run(self):
    method _handler (line 48) | async def _handler(self, websocket: _server.ServerConnection):
  function _health_check (line 86) | def _health_check(connection: _server.ServerConnection, request: _server...

FILE: src/openpi/shared/array_typing.py
  function _check_dataclass_annotations (line 34) | def _check_dataclass_annotations(self, typechecker):
  function typecheck (line 52) | def typecheck(t: T) -> T:
  function disable_typechecking (line 57) | def disable_typechecking():
  function check_pytree_equality (line 64) | def check_pytree_equality(*, expected: PyTree, got: PyTree, check_shapes...

FILE: src/openpi/shared/download.py
  function get_cache_dir (line 25) | def get_cache_dir() -> pathlib.Path:
  function maybe_download (line 32) | def maybe_download(url: str, *, force_download: bool = False, **kwargs) ...
  function _download_gsutil (line 108) | def _download_gsutil(url: str, local_path: pathlib.Path, **kwargs) -> None:
  function _download_fsspec (line 123) | def _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None:
  function _set_permission (line 142) | def _set_permission(path: pathlib.Path, target_permission: int):
  function _set_folder_permission (line 151) | def _set_folder_permission(folder_path: pathlib.Path) -> None:
  function _ensure_permissions (line 156) | def _ensure_permissions(path: pathlib.Path) -> None:
  function _get_mtime (line 189) | def _get_mtime(year: int, month: int, day: int) -> float:
  function _should_invalidate_cache (line 205) | def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathli...

FILE: src/openpi/shared/download_test.py
  function set_openpi_data_home (line 9) | def set_openpi_data_home(tmp_path_factory):
  function test_download_local (line 16) | def test_download_local(tmp_path: pathlib.Path):
  function test_download_gs_dir (line 27) | def test_download_gs_dir():
  function test_download_gs (line 37) | def test_download_gs():
  function test_download_fsspec (line 47) | def test_download_fsspec():

FILE: src/openpi/shared/image_tools.py
  function resize_with_pad (line 13) | def resize_with_pad(
  function resize_with_pad_torch (line 55) | def resize_with_pad_torch(

FILE: src/openpi/shared/image_tools_test.py
  function test_resize_with_pad_shapes (line 6) | def test_resize_with_pad_shapes():

FILE: src/openpi/shared/nnx_utils.py
  function module_jit (line 15) | def module_jit(meth: Callable[P, R], *jit_args, **jit_kwargs) -> Callabl...
  class PathRegex (line 47) | class PathRegex:
    method __post_init__ (line 56) | def __post_init__(self):
    method __call__ (line 60) | def __call__(self, path: nnx.filterlib.PathParts, x: Any) -> bool:
  function state_map (line 66) | def state_map(state: nnx.State, filter: nnx.filterlib.Filter, fn: Callab...

FILE: src/openpi/shared/normalize.py
  class NormStats (line 10) | class NormStats:
  class RunningStats (line 17) | class RunningStats:
    method __init__ (line 20) | def __init__(self):
    method update (line 30) | def update(self, batch: np.ndarray) -> None:
    method get_statistics (line 73) | def get_statistics(self) -> NormStats:
    method _adjust_histograms (line 88) | def _adjust_histograms(self):
    method _update_histograms (line 100) | def _update_histograms(self, batch: np.ndarray) -> None:
    method _compute_quantiles (line 106) | def _compute_quantiles(self, quantiles):
  class _NormStatsDict (line 120) | class _NormStatsDict(pydantic.BaseModel):
  function serialize_json (line 124) | def serialize_json(norm_stats: dict[str, NormStats]) -> str:
  function deserialize_json (line 129) | def deserialize_json(data: str) -> dict[str, NormStats]:
  function save (line 134) | def save(directory: pathlib.Path | str, norm_stats: dict[str, NormStats]...
  function load (line 141) | def load(directory: pathlib.Path | str) -> dict[str, NormStats]:

FILE: src/openpi/shared/normalize_test.py
  function test_normalize_update (line 6) | def test_normalize_update():
  function test_serialize_deserialize (line 18) | def test_serialize_deserialize():
  function test_multiple_batch_dimensions (line 28) | def test_multiple_batch_dimensions():

FILE: src/openpi/training/checkpoints.py
  function initialize_checkpoint_dir (line 20) | def initialize_checkpoint_dir(
  function save_state (line 65) | def save_state(
  function restore_state (line 89) | def restore_state(
  function load_norm_stats (line 110) | def load_norm_stats(assets_dir: epath.Path | str, asset_id: str) -> dict...
  class Callback (line 117) | class Callback(Protocol):
    method __call__ (line 118) | def __call__(self, directory: epath.Path) -> None: ...
  class CallbackHandler (line 121) | class CallbackHandler(ocp.AsyncCheckpointHandler):
    method save (line 124) | def save(self, directory: epath.Path, args: CallbackSave):
    method async_save (line 128) | async def async_save(self, directory: epath.Path, args: CallbackSave) ...
    method restore (line 131) | def restore(self, *args, **kwargs):
  class CallbackSave (line 137) | class CallbackSave(ocp.args.CheckpointArgs):
  class CallbackRestore (line 142) | class CallbackRestore(ocp.args.CheckpointArgs): ...
  function _split_params (line 145) | def _split_params(state: training_utils.TrainState) -> tuple[training_ut...
  function _merge_params (line 155) | def _merge_params(train_state: training_utils.TrainState, params: dict[s...

FILE: src/openpi/training/config.py
  class AssetsConfig (line 38) | class AssetsConfig:
  class DataConfig (line 65) | class DataConfig:
  class GroupFactory (line 101) | class GroupFactory(Protocol):
    method __call__ (line 102) | def __call__(self, model_config: _model.BaseModelConfig) -> _transform...
  class ModelTransformFactory (line 107) | class ModelTransformFactory(GroupFactory):
    method __call__ (line 113) | def __call__(self, model_config: _model.BaseModelConfig) -> _transform...
  class DataConfigFactory (line 167) | class DataConfigFactory(abc.ABC):
    method create (line 176) | def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseM...
    method create_base_config (line 179) | def create_base_config(self, assets_dirs: pathlib.Path, model_config: ...
    method _load_norm_stats (line 190) | def _load_norm_stats(self, assets_dir: epath.Path, asset_id: str | Non...
  class FakeDataConfig (line 204) | class FakeDataConfig(DataConfigFactory):
    method create (line 208) | def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseM...
  class SimpleDataConfig (line 213) | class SimpleDataConfig(DataConfigFactory):
    method create (line 220) | def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseM...
  class LeRobotAlohaDataConfig (line 229) | class LeRobotAlohaDataConfig(DataConfigFactory):
    method create (line 258) | def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseM...
  class LeRobotLiberoDataConfig (line 282) | class LeRobotLiberoDataConfig(DataConfigFactory):
    method create (line 292) | def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseM...
  class RLDSDroidDataConfig (line 359) | class RLDSDroidDataConfig(DataConfigFactory):
    method create (line 382) | def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseM...
  class LeRobotDROIDDataConfig (line 427) | class LeRobotDROIDDataConfig(DataConfigFactory):
    method create (line 434) | def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseM...
  class TrainConfig (line 466) | class TrainConfig:
    method assets_dirs (line 538) | def assets_dirs(self) -> pathlib.Path:
    method checkpoint_dir (line 543) | def checkpoint_dir(self) -> pathlib.Path:
    method trainable_filter (line 550) | def trainable_filter(self) -> nnx.filterlib.Filter:
    method __post_init__ (line 554) | def __post_init__(self) -> None:
  function cli (line 978) | def cli() -> TrainConfig:
  function get_config (line 982) | def get_config(config_name: str) -> TrainConfig:

FILE: src/openpi/training/data_loader.py
  class Dataset (line 22) | class Dataset(Protocol[T_co]):
    method __getitem__ (line 25) | def __getitem__(self, index: SupportsIndex) -> T_co:
    method __len__ (line 28) | def __len__(self) -> int:
  class IterableDataset (line 32) | class IterableDataset(Protocol[T_co]):
    method __iter__ (line 35) | def __iter__(self) -> Iterator[T_co]:
    method __len__ (line 38) | def __len__(self) -> int:
  class DataLoader (line 42) | class DataLoader(Protocol[T_co]):
    method data_config (line 45) | def data_config(self) -> _config.DataConfig:
    method __iter__ (line 49) | def __iter__(self) -> Iterator[T_co]:
  class TransformedDataset (line 53) | class TransformedDataset(Dataset[T_co]):
    method __init__ (line 54) | def __init__(self, dataset: Dataset, transforms: Sequence[_transforms....
    method __getitem__ (line 58) | def __getitem__(self, index: SupportsIndex) -> T_co:
    method __len__ (line 61) | def __len__(self) -> int:
  class IterableTransformedDataset (line 65) | class IterableTransformedDataset(IterableDataset[T_co]):
    method __init__ (line 66) | def __init__(
    method __iter__ (line 77) | def __iter__(self):
    method __len__ (line 95) | def __len__(self) -> int:
  class FakeDataset (line 99) | class FakeDataset(Dataset):
    method __init__ (line 100) | def __init__(self, model_config: _model.BaseModelConfig, num_samples: ...
    method __getitem__ (line 104) | def __getitem__(self, index: SupportsIndex) -> dict:
    method __len__ (line 126) | def __len__(self) -> int:
  function create_torch_dataset (line 130) | def create_torch_dataset(
  function create_rlds_dataset (line 154) | def create_rlds_dataset(
  function transform_dataset (line 172) | def transform_dataset(dataset: Dataset, data_config: _config.DataConfig,...
  function transform_iterable_dataset (line 194) | def transform_iterable_dataset(
  function create_data_loader (line 223) | def create_data_loader(
  function create_torch_data_loader (line 271) | def create_torch_data_loader(
  function create_rlds_data_loader (line 340) | def create_rlds_data_loader(
  class TorchDataLoader (line 381) | class TorchDataLoader:
    method __init__ (line 384) | def __init__(
    method torch_loader (line 449) | def torch_loader(self) -> torch.utils.data.DataLoader:
    method __iter__ (line 452) | def __iter__(self):
  function _collate_fn (line 471) | def _collate_fn(items):
  function _worker_init_fn (line 478) | def _worker_init_fn(worker_id: int) -> None:
  class RLDSDataLoader (line 486) | class RLDSDataLoader:
    method __init__ (line 492) | def __init__(
    method __iter__ (line 515) | def __iter__(self):
  class DataLoaderImpl (line 530) | class DataLoaderImpl(DataLoader):
    method __init__ (line 531) | def __init__(self, data_config: _config.DataConfig, data_loader: Torch...
    method data_config (line 535) | def data_config(self) -> _config.DataConfig:
    method __iter__ (line 538) | def __iter__(self):

FILE: src/openpi/training/data_loader_test.py
  function test_torch_data_loader (line 10) | def test_torch_data_loader():
  function test_torch_data_loader_infinite (line 26) | def test_torch_data_loader_infinite():
  function test_torch_data_loader_parallel (line 37) | def test_torch_data_loader_parallel():
  function test_with_fake_dataset (line 50) | def test_with_fake_dataset():
  function test_with_real_dataset (line 65) | def test_with_real_dataset():

FILE: src/openpi/training/droid_rlds_dataset.py
  class DroidActionSpace (line 21) | class DroidActionSpace(Enum):
  class RLDSDataset (line 29) | class RLDSDataset:
  class DroidRldsDataset (line 36) | class DroidRldsDataset:
    method __init__ (line 37) | def __init__(
    method __iter__ (line 242) | def __iter__(self):
    method __len__ (line 245) | def __len__(self):

FILE: src/openpi/training/misc/polaris_config.py
  function get_polaris_configs (line 18) | def get_polaris_configs():

FILE: src/openpi/training/misc/roboarena_config.py
  function get_roboarena_configs (line 15) | def get_roboarena_configs():

FILE: src/openpi/training/optimizer.py
  class LRScheduleConfig (line 11) | class LRScheduleConfig(Protocol):
    method create (line 12) | def create(self) -> optax.Schedule: ...
  class CosineDecaySchedule (line 16) | class CosineDecaySchedule(LRScheduleConfig):
    method create (line 24) | def create(self) -> optax.Schedule:
  class RsqrtDecaySchedule (line 35) | class RsqrtDecaySchedule(LRScheduleConfig):
    method create (line 42) | def create(self) -> optax.Schedule:
  class OptimizerConfig (line 57) | class OptimizerConfig(Protocol):
    method create (line 58) | def create(
  class AdamW (line 66) | class AdamW(OptimizerConfig):
    method create (line 76) | def create(
  class SGD (line 89) | class SGD(OptimizerConfig):
    method create (line 96) | def create(
  function create_optimizer (line 105) | def create_optimizer(

FILE: src/openpi/training/sharding.py
  class _MeshState (line 13) | class _MeshState:
  function make_mesh (line 17) | def make_mesh(num_fsdp_devices: int) -> jax.sharding.Mesh:
  function set_mesh (line 27) | def set_mesh(mesh: jax.sharding.Mesh):
  function activation_sharding_constraint (line 40) | def activation_sharding_constraint(pytree):
  function fsdp_sharding (line 48) | def fsdp_sharding(

FILE: src/openpi/training/utils.py
  class TrainState (line 15) | class TrainState:
  function tree_to_info (line 27) | def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = st...
  function array_tree_to_info (line 36) | def array_tree_to_info(tree: at.PyTree) -> str:

FILE: src/openpi/training/weight_loaders.py
  class WeightLoader (line 17) | class WeightLoader(Protocol):
    method load (line 18) | def load(self, params: at.Params) -> at.Params:
  class NoOpWeightLoader (line 32) | class NoOpWeightLoader(WeightLoader):
    method load (line 33) | def load(self, params: at.Params) -> at.Params:
  class CheckpointWeightLoader (line 38) | class CheckpointWeightLoader(WeightLoader):
    method load (line 50) | def load(self, params: at.Params) -> at.Params:
  class PaliGemmaWeightLoader (line 58) | class PaliGemmaWeightLoader(WeightLoader):
    method load (line 65) | def load(self, params: at.Params) -> at.Params:
  function _merge_params (line 76) | def _merge_params(loaded_params: at.Params, params: at.Params, *, missin...

FILE: src/openpi/transforms.py
  class DataTransformFn (line 24) | class DataTransformFn(Protocol):
    method __call__ (line 25) | def __call__(self, data: DataDict) -> DataDict:
  class Group (line 40) | class Group:
    method push (line 49) | def push(self, *, inputs: Sequence[DataTransformFn] = (), outputs: Seq...
  class CompositeTransform (line 63) | class CompositeTransform(DataTransformFn):
    method __call__ (line 68) | def __call__(self, data: DataDict) -> DataDict:
  function compose (line 74) | def compose(transforms: Sequence[DataTransformFn]) -> DataTransformFn:
  class RepackTransform (line 80) | class RepackTransform(DataTransformFn):
    method __call__ (line 99) | def __call__(self, data: DataDict) -> DataDict:
  class InjectDefaultPrompt (line 105) | class InjectDefaultPrompt(DataTransformFn):
    method __call__ (line 108) | def __call__(self, data: DataDict) -> DataDict:
  class Normalize (line 115) | class Normalize(DataTransformFn):
    method __post_init__ (line 122) | def __post_init__(self):
    method __call__ (line 126) | def __call__(self, data: DataDict) -> DataDict:
    method _normalize (line 137) | def _normalize(self, x, stats: NormStats):
    method _normalize_quantile (line 141) | def _normalize_quantile(self, x, stats: NormStats):
  class Unnormalize (line 149) | class Unnormalize(DataTransformFn):
    method __post_init__ (line 154) | def __post_init__(self):
    method __call__ (line 158) | def __call__(self, data: DataDict) -> DataDict:
    method _unnormalize (line 170) | def _unnormalize(self, x, stats: NormStats):
    method _unnormalize_quantile (line 175) | def _unnormalize_quantile(self, x, stats: NormStats):
  class ResizeImages (line 185) | class ResizeImages(DataTransformFn):
    method __call__ (line 189) | def __call__(self, data: DataDict) -> DataDict:
  class SubsampleActions (line 195) | class SubsampleActions(DataTransformFn):
    method __call__ (line 198) | def __call__(self, data: DataDict) -> DataDict:
  class DeltaActions (line 204) | class DeltaActions(DataTransformFn):
    method __call__ (line 212) | def __call__(self, data: DataDict) -> DataDict:
  class AbsoluteActions (line 226) | class AbsoluteActions(DataTransformFn):
    method __call__ (line 234) | def __call__(self, data: DataDict) -> DataDict:
  class TokenizePrompt (line 248) | class TokenizePrompt(DataTransformFn):
    method __call__ (line 252) | def __call__(self, data: DataDict) -> DataDict:
  class TokenizeFASTInputs (line 270) | class TokenizeFASTInputs(DataTransformFn):
    method __call__ (line 273) | def __call__(self, data: DataDict) -> DataDict:
  class ExtractFASTActions (line 292) | class ExtractFASTActions(DataTransformFn):
    method __call__ (line 297) | def __call__(self, data: DataDict) -> DataDict:
  class PromptFromLeRobotTask (line 310) | class PromptFromLeRobotTask(DataTransformFn):
    method __call__ (line 316) | def __call__(self, data: DataDict) -> DataDict:
  class PadStatesAndActions (line 328) | class PadStatesAndActions(DataTransformFn):
    method __call__ (line 333) | def __call__(self, data: DataDict) -> DataDict:
  function flatten_dict (line 340) | def flatten_dict(tree: at.PyTree) -> dict:
  function unflatten_dict (line 345) | def unflatten_dict(tree: dict) -> at.PyTree:
  function transform_dict (line 350) | def transform_dict(patterns: Mapping[str, str | None], tree: at.PyTree) ...
  function apply_tree (line 404) | def apply_tree(
  function pad_to_dim (line 423) | def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: fl...
  function make_bool_mask (line 433) | def make_bool_mask(*dims: int) -> tuple[bool, ...]:
  function _assert_quantile_stats (line 455) | def _assert_quantile_stats(norm_stats: at.PyTree[NormStats]) -> None:

FILE: src/openpi/transforms_test.py
  function test_repack_transform (line 8) | def test_repack_transform():
  function test_delta_actions (line 19) | def test_delta_actions():
  function test_delta_actions_noop (line 29) | def test_delta_actions_noop():
  function test_absolute_actions (line 42) | def test_absolute_actions():
  function test_absolute_actions_noop (line 52) | def test_absolute_actions_noop():
  function test_make_bool_mask (line 65) | def test_make_bool_mask():
  function test_tokenize_prompt (line 70) | def test_tokenize_prompt():
  function test_tokenize_no_prompt (line 81) | def test_tokenize_no_prompt():
  function test_transform_dict (line 88) | def test_transform_dict():
  function test_extract_prompt_from_task (line 114) | def test_extract_prompt_from_task():
Condensed preview: 139 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".dockerignore",
    "chars": 23,
    "preview": ".venv\ncheckpoints\ndata\n"
  },
  {
    "path": ".github/CODEOWNERS",
    "chars": 632,
    "preview": "# The CODEOWNERS file defines individuals or teams that are automatically requested for\n# review when someone opens a pu"
  },
  {
    "path": ".github/workflows/pre-commit.yml",
    "chars": 308,
    "preview": "name: pre-commit\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - \"*\"\njobs:\n  pre-commit:\n  "
  },
  {
    "path": ".github/workflows/test.yml",
    "chars": 669,
    "preview": "name: Test\non:\n  pull_request:\n    branches:\n      - \"*\"\n\njobs:\n  run_tests:\n    name: Run Tests\n    runs-on: openpi-ver"
  },
  {
    "path": ".gitignore",
    "chars": 3194,
    "preview": "# Data directories.\nassets/\ncheckpoints/\ndata/\nwandb/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$"
  },
  {
    "path": ".gitmodules",
    "chars": 237,
    "preview": "[submodule \"third_party/aloha\"]\n\tpath = third_party/aloha\n\turl = https://github.com/Physical-Intelligence/aloha.git\n[sub"
  },
  {
    "path": ".pre-commit-config.yaml",
    "chars": 336,
    "preview": "exclude: third_party/\n\nrepos:\n  - repo: https://github.com/astral-sh/uv-pre-commit\n    # uv version.\n    rev: 0.5.14\n   "
  },
  {
    "path": ".python-version",
    "chars": 4,
    "preview": "3.11"
  },
  {
    "path": ".vscode/settings.json",
    "chars": 264,
    "preview": "{\n    \"[python]\": {\n        \"editor.defaultFormatter\": \"charliermarsh.ruff\",\n        \"editor.formatOnSave\": true,\n    },"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 3364,
    "preview": "# Contributing to openpi\n\nWe welcome contributions, improvements, and modifications. Everyone is welcome to use openpi i"
  },
  {
    "path": "LICENSE",
    "chars": 11356,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "LICENSE_GEMMA.txt",
    "chars": 8431,
    "preview": "Gemma Terms of Use \n\nLast modified: February 21, 2024\n\nBy using, reproducing, modifying, distributing, performing or dis"
  },
  {
    "path": "README.md",
    "chars": 23207,
    "preview": "# openpi\n\nopenpi holds open-source models and packages for robotics, published by the [Physical Intelligence team](https"
  },
  {
    "path": "docs/docker.md",
    "chars": 1989,
    "preview": "### Docker Setup\n\nAll of the examples in this repo provide instructions for being run normally, and also using Docker. A"
  },
  {
    "path": "docs/norm_stats.md",
    "chars": 5245,
    "preview": "# Normalization statistics\n\nFollowing common practice, our models normalize the proprioceptive state inputs and action t"
  },
  {
    "path": "docs/remote_inference.md",
    "chars": 3601,
    "preview": "\n# Running openpi models remotely\n\nWe provide utilities for running openpi models remotely. This is useful for running i"
  },
  {
    "path": "examples/aloha_real/Dockerfile",
    "chars": 2650,
    "preview": "# Dockerfile for the Aloha real environment.\n\n# Build the container:\n# docker build . -t aloha_real -f examples/aloha_re"
  },
  {
    "path": "examples/aloha_real/README.md",
    "chars": 6426,
    "preview": "# Run Aloha (Real Robot)\n\nThis example demonstrates how to run with a real robot using an [ALOHA setup](https://github.c"
  },
  {
    "path": "examples/aloha_real/compose.yml",
    "chars": 1440,
    "preview": "# Run with:\n# docker compose -f examples/aloha_real/compose.yml up --build\nservices:\n  runtime:\n    image: aloha_real\n  "
  },
  {
    "path": "examples/aloha_real/constants.py",
    "chars": 3296,
    "preview": "# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).\n# ruff: noqa\n\n### "
  },
  {
    "path": "examples/aloha_real/convert_aloha_data_to_lerobot.py",
    "chars": 7515,
    "preview": "\"\"\"\nScript to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.\n\nExample usage: uv run examples/aloha_real/con"
  },
  {
    "path": "examples/aloha_real/env.py",
    "chars": 1741,
    "preview": "from typing import List, Optional  # noqa: UP035\n\nimport einops\nfrom openpi_client import image_tools\nfrom openpi_client"
  },
  {
    "path": "examples/aloha_real/main.py",
    "chars": 1408,
    "preview": "import dataclasses\nimport logging\n\nfrom openpi_client import action_chunk_broker\nfrom openpi_client import websocket_cli"
  },
  {
    "path": "examples/aloha_real/real_env.py",
    "chars": 8650,
    "preview": "# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).\n# ruff: noqa\nimpor"
  },
  {
    "path": "examples/aloha_real/requirements.in",
    "chars": 183,
    "preview": "Pillow\ndm_control\neinops\nh5py\nmatplotlib\nmodern_robotics\nmsgpack\nnumpy>=1.22.4,<2.0.0\nopencv-python\npackaging\npexpect\npy"
  },
  {
    "path": "examples/aloha_real/requirements.txt",
    "chars": 3337,
    "preview": "# This file was autogenerated by uv via the following command:\n#    uv pip compile examples/aloha_real/requirements.in -"
  },
  {
    "path": "examples/aloha_real/robot_utils.py",
    "chars": 9867,
    "preview": "# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).\n# ruff: noqa\nfrom "
  },
  {
    "path": "examples/aloha_real/video_display.py",
    "chars": 958,
    "preview": "import matplotlib.pyplot as plt\nimport numpy as np\nfrom openpi_client.runtime import subscriber as _subscriber\nfrom typi"
  },
  {
    "path": "examples/aloha_sim/Dockerfile",
    "chars": 1488,
    "preview": "# Dockerfile for the Aloha simulation environment.\n\n# Build the container:\n# docker build . -t aloha_sim -f examples/alo"
  },
  {
    "path": "examples/aloha_sim/README.md",
    "chars": 732,
    "preview": "# Run Aloha Sim\n\n## With Docker\n\n```bash\nexport SERVER_ARGS=\"--env ALOHA_SIM\"\ndocker compose -f examples/aloha_sim/compo"
  },
  {
    "path": "examples/aloha_sim/compose.yml",
    "chars": 963,
    "preview": "# Run with:\n# docker compose -f examples/aloha_sim/compose.yml up --build\nservices:\n  runtime:\n    image: aloha_sim\n    "
  },
  {
    "path": "examples/aloha_sim/env.py",
    "chars": 1942,
    "preview": "import gym_aloha  # noqa: F401\nimport gymnasium\nimport numpy as np\nfrom openpi_client import image_tools\nfrom openpi_cli"
  },
  {
    "path": "examples/aloha_sim/main.py",
    "chars": 1370,
    "preview": "import dataclasses\nimport logging\nimport pathlib\n\nimport env as _env\nfrom openpi_client import action_chunk_broker\nfrom "
  },
  {
    "path": "examples/aloha_sim/requirements.in",
    "chars": 91,
    "preview": "gym-aloha\nimageio\nmatplotlib\nmsgpack\nnumpy>=1.22.4,<2.0.0\ntyping-extensions\ntyro\nwebsockets"
  },
  {
    "path": "examples/aloha_sim/requirements.txt",
    "chars": 2629,
    "preview": "# This file was autogenerated by uv via the following command:\n#    uv pip compile examples/aloha_sim/requirements.in -o"
  },
  {
    "path": "examples/aloha_sim/saver.py",
    "chars": 1270,
    "preview": "import logging\nimport pathlib\n\nimport imageio\nimport numpy as np\nfrom openpi_client.runtime import subscriber as _subscr"
  },
  {
    "path": "examples/convert_jax_model_to_pytorch.py",
    "chars": 27520,
    "preview": "#!/usr/bin/env python3\n\"\"\"\nLoad a JAX model and print all parameter keys, with optional conversion to PyTorch.\n\nThis scr"
  },
  {
    "path": "examples/droid/README.md",
    "chars": 6870,
    "preview": "# DROID Policies in openpi\n\nWe offer instructions for:\n- [Running inference for our best $pi_{0.5}$-DROID policy](./READ"
  },
  {
    "path": "examples/droid/README_train.md",
    "chars": 7329,
    "preview": "# Training on DROID\n\nHere we describe how to fine-tune the pi0.5 model on the *full* DROID dataset. This is an approxima"
  },
  {
    "path": "examples/droid/compute_droid_nonidle_ranges.py",
    "chars": 4872,
    "preview": "\"\"\"\nIterates through the DROID dataset and creates a json mapping from episode unique IDs to ranges of time steps\nthat s"
  },
  {
    "path": "examples/droid/convert_droid_data_to_lerobot.py",
    "chars": 17080,
    "preview": "\"\"\"\nMinimal example script for converting a dataset collected on the DROID platform to LeRobot format.\n\nUsage:\nuv run ex"
  },
  {
    "path": "examples/droid/main.py",
    "chars": 9825,
    "preview": "# ruff: noqa\n\nimport contextlib\nimport dataclasses\nimport datetime\nimport faulthandler\nimport os\nimport signal\nimport ti"
  },
  {
    "path": "examples/inference.ipynb",
    "chars": 5419,
    "preview": "{\n    \"cells\": [\n        {\n            \"cell_type\": \"code\",\n            \"execution_count\": 1,\n            \"metadata\": {}"
  },
  {
    "path": "examples/libero/Dockerfile",
    "chars": 2615,
    "preview": "# Dockerfile for the LIBERO benchmark.\n\n# Build the container:\n# docker build . -t libero -f examples/libero/Dockerfile\n"
  },
  {
    "path": "examples/libero/README.md",
    "chars": 2550,
    "preview": "# LIBERO Benchmark\n\nThis example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO\n\nNote: Whe"
  },
  {
    "path": "examples/libero/compose.yml",
    "chars": 1349,
    "preview": "# Run with:\n# docker compose -f examples/libero/compose.yml up --build\nservices:\n  runtime:\n    image: libero\n    depend"
  },
  {
    "path": "examples/libero/convert_libero_data_to_lerobot.py",
    "chars": 3766,
    "preview": "\"\"\"\nMinimal example script for converting a dataset to LeRobot format.\n\nWe use the Libero dataset (stored in RLDS) for t"
  },
  {
    "path": "examples/libero/main.py",
    "chars": 9194,
    "preview": "import collections\nimport dataclasses\nimport logging\nimport math\nimport pathlib\n\nimport imageio\nfrom libero.libero impor"
  },
  {
    "path": "examples/libero/requirements.in",
    "chars": 177,
    "preview": "imageio[ffmpeg]\nnumpy==1.22.4\ntqdm\ntyro\nPyYaml\nopencv-python==4.6.0.66\ntorch==1.11.0+cu113\ntorchvision==0.12.0+cu113\ntor"
  },
  {
    "path": "examples/libero/requirements.txt",
    "chars": 2842,
    "preview": "# This file was autogenerated by uv via the following command:\n#    uv pip compile examples/libero/requirements.in -o ex"
  },
  {
    "path": "examples/policy_records.ipynb",
    "chars": 3395,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": "
  },
  {
    "path": "examples/simple_client/Dockerfile",
    "chars": 1256,
    "preview": "# Dockerfile for the simple client.\n\n# Build the container:\n# docker build . -t simple_client -f examples/simple_client/"
  },
  {
    "path": "examples/simple_client/README.md",
    "chars": 589,
    "preview": "# Simple Client\n\nA minimal client that sends observations to the server and prints the inference rate.\n\nYou can specify "
  },
  {
    "path": "examples/simple_client/compose.yml",
    "chars": 968,
    "preview": "# Run with:\n# docker compose -f examples/simple_client/compose.yml up --build\nservices:\n  runtime:\n    image: simple_cli"
  },
  {
    "path": "examples/simple_client/main.py",
    "chars": 6281,
    "preview": "import dataclasses\nimport enum\nimport logging\nimport pathlib\nimport time\n\nimport numpy as np\nfrom openpi_client import w"
  },
  {
    "path": "examples/simple_client/requirements.in",
    "chars": 42,
    "preview": "numpy>=1.22.4,<2.0.0\nrich\ntqdm\ntyro\npolars"
  },
  {
    "path": "examples/simple_client/requirements.txt",
    "chars": 810,
    "preview": "# This file was autogenerated by uv via the following command:\n#    uv pip compile examples/simple_client/requirements.i"
  },
  {
    "path": "examples/ur5/README.md",
    "chars": 5846,
    "preview": "# UR5 Example\n\nBelow we provide an outline of how to implement the key components mentioned in the \"Finetune on your dat"
  },
  {
    "path": "packages/openpi-client/pyproject.toml",
    "chars": 410,
    "preview": "[project]\nname = \"openpi-client\"\nversion = \"0.1.0\"\nrequires-python = \">=3.7\"\ndependencies = [\n    \"dm-tree>=0.1.8\",\n    "
  },
  {
    "path": "packages/openpi-client/src/openpi_client/__init__.py",
    "chars": 22,
    "preview": "__version__ = \"0.1.0\"\n"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/action_chunk_broker.py",
    "chars": 1404,
    "preview": "from typing import Dict\n\nimport numpy as np\nimport tree\nfrom typing_extensions import override\n\nfrom openpi_client impor"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/base_policy.py",
    "chars": 271,
    "preview": "import abc\nfrom typing import Dict\n\n\nclass BasePolicy(abc.ABC):\n    @abc.abstractmethod\n    def infer(self, obs: Dict) -"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/image_tools.py",
    "chars": 2426,
    "preview": "import numpy as np\nfrom PIL import Image\n\n\ndef convert_to_uint8(img: np.ndarray) -> np.ndarray:\n    \"\"\"Converts an image"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/image_tools_test.py",
    "chars": 1411,
    "preview": "import numpy as np\n\nimport openpi_client.image_tools as image_tools\n\n\ndef test_resize_with_pad_shapes():\n    # Test case"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/msgpack_numpy.py",
    "chars": 1944,
    "preview": "\"\"\"Adds NumPy array support to msgpack.\n\nmsgpack is good for (de)serializing data over a network for multiple reasons:\n-"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/msgpack_numpy_test.py",
    "chars": 1643,
    "preview": "import numpy as np\nimport pytest\nimport tree\n\nfrom openpi_client import msgpack_numpy\n\n\ndef _check(expected, actual):\n  "
  },
  {
    "path": "packages/openpi-client/src/openpi_client/runtime/agent.py",
    "chars": 465,
    "preview": "import abc\n\n\nclass Agent(abc.ABC):\n    \"\"\"An Agent is the thing with agency, i.e. the entity that makes decisions.\n\n    "
  },
  {
    "path": "packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py",
    "chars": 509,
    "preview": "from typing_extensions import override\n\nfrom openpi_client import base_policy as _base_policy\nfrom openpi_client.runtime"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/runtime/environment.py",
    "chars": 1044,
    "preview": "import abc\n\n\nclass Environment(abc.ABC):\n    \"\"\"An Environment represents the robot and the environment it inhabits.\n\n  "
  },
  {
    "path": "packages/openpi-client/src/openpi_client/runtime/runtime.py",
    "chars": 3049,
    "preview": "import logging\nimport threading\nimport time\n\nfrom openpi_client.runtime import agent as _agent\nfrom openpi_client.runtim"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/runtime/subscriber.py",
    "chars": 500,
    "preview": "import abc\n\n\nclass Subscriber(abc.ABC):\n    \"\"\"Subscribes to events in the runtime.\n\n    Subscribers can be used to save"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/websocket_client_policy.py",
    "chars": 2131,
    "preview": "import logging\nimport time\nfrom typing import Dict, Optional, Tuple\n\nfrom typing_extensions import override\nimport webso"
  },
  {
    "path": "pyproject.toml",
    "chars": 2958,
    "preview": "[project]\nname = \"openpi\"\nversion = \"0.1.0\"\ndescription = \"Physical Intelligence open source repo\"\nreadme = \"README.md\"\n"
  },
  {
    "path": "scripts/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "scripts/compute_norm_stats.py",
    "chars": 4042,
    "preview": "\"\"\"Compute normalization statistics for a config.\n\nThis script is used to compute the normalization statistics for a giv"
  },
  {
    "path": "scripts/docker/compose.yml",
    "chars": 814,
    "preview": "# Run with:\n# docker compose -f scripts/docker/compose.yml up --build\nservices:\n  openpi_server:\n    image: openpi_serve"
  },
  {
    "path": "scripts/docker/install_docker_ubuntu22.sh",
    "chars": 1529,
    "preview": "#!/bin/bash\n\n# Add Docker's official GPG key:\nsudo apt-get update\nsudo apt-get install -y ca-certificates curl\nsudo inst"
  },
  {
    "path": "scripts/docker/install_nvidia_container_toolkit.sh",
    "chars": 986,
    "preview": "#!/bin/bash\n\n# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.\n# NVIDIA's o"
  },
  {
    "path": "scripts/docker/serve_policy.Dockerfile",
    "chars": 1916,
    "preview": "# Dockerfile for serving a PI policy.\n# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/"
  },
  {
    "path": "scripts/serve_policy.py",
    "chars": 3687,
    "preview": "import dataclasses\nimport enum\nimport logging\nimport socket\n\nimport tyro\n\nfrom openpi.policies import policy as _policy\n"
  },
  {
    "path": "scripts/train.py",
    "chars": 10488,
    "preview": "import dataclasses\nimport functools\nimport logging\nimport platform\nfrom typing import Any\n\nimport etils.epath as epath\ni"
  },
  {
    "path": "scripts/train_pytorch.py",
    "chars": 25938,
    "preview": "\"\"\"\nPyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.\nThis script mirrors the behavi"
  },
  {
    "path": "scripts/train_test.py",
    "chars": 718,
    "preview": "import dataclasses\nimport os\nimport pathlib\n\nimport pytest\n\nos.environ[\"JAX_PLATFORMS\"] = \"cpu\"\n\nfrom openpi.training im"
  },
  {
    "path": "src/openpi/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/openpi/conftest.py",
    "chars": 339,
    "preview": "import os\n\nimport pynvml\nimport pytest\n\n\ndef set_jax_cpu_backend_if_no_gpu() -> None:\n    try:\n        pynvml.nvmlInit()"
  },
  {
    "path": "src/openpi/models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/openpi/models/gemma.py",
    "chars": 17146,
    "preview": "# Copyright 2024 Big Vision Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not u"
  },
  {
    "path": "src/openpi/models/gemma_fast.py",
    "chars": 15643,
    "preview": "# Copyright 2024 Big Vision Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not u"
  },
  {
    "path": "src/openpi/models/lora.py",
    "chars": 5342,
    "preview": "import math\nimport re\n\nimport flax.linen as nn\nimport flax.struct as struct\nimport jax.numpy as jnp\n\nimport openpi.share"
  },
  {
    "path": "src/openpi/models/lora_test.py",
    "chars": 3138,
    "preview": "import flax.linen as nn\nimport jax\nimport jax.numpy as jnp\n\nimport openpi.models.lora as lora\n\n\ndef test_lora_einsum_par"
  },
  {
    "path": "src/openpi/models/model.py",
    "chars": 12270,
    "preview": "import abc\nfrom collections.abc import Sequence\nimport dataclasses\nimport enum\nimport logging\nimport pathlib\nfrom typing"
  },
  {
    "path": "src/openpi/models/model_test.py",
    "chars": 2966,
    "preview": "from flax import nnx\nimport jax\nimport pytest\n\nfrom openpi.models import model as _model\nfrom openpi.models import pi0_c"
  },
  {
    "path": "src/openpi/models/pi0.py",
    "chars": 12887,
    "preview": "import logging\n\nimport einops\nimport flax.nnx as nnx\nimport flax.nnx.bridge as nnx_bridge\nimport jax\nimport jax.numpy as"
  },
  {
    "path": "src/openpi/models/pi0_config.py",
    "chars": 4435,
    "preview": "import dataclasses\nfrom typing import TYPE_CHECKING\n\nimport flax.nnx as nnx\nimport jax\nimport jax.numpy as jnp\nfrom typi"
  },
  {
    "path": "src/openpi/models/pi0_fast.py",
    "chars": 13209,
    "preview": "import dataclasses\nimport logging\nfrom typing import Any\n\nimport einops\nimport flax.nnx as nnx\nimport flax.nnx.bridge as"
  },
  {
    "path": "src/openpi/models/pi0_test.py",
    "chars": 1631,
    "preview": "import flax.nnx as nnx\nimport jax\n\nimport openpi.models.pi0_config as _pi0_config\n\n\ndef _get_frozen_state(config: _pi0_c"
  },
  {
    "path": "src/openpi/models/siglip.py",
    "chars": 12078,
    "preview": "# Copyright 2024 Big Vision Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not u"
  },
  {
    "path": "src/openpi/models/tokenizer.py",
    "chars": 17517,
    "preview": "import logging\nimport os\n\nimport jax\nimport numpy as np\nimport orbax.checkpoint as ocp\nimport sentencepiece\nfrom transfo"
  },
  {
    "path": "src/openpi/models/tokenizer_test.py",
    "chars": 807,
    "preview": "import numpy as np\n\nfrom openpi.models import tokenizer as _tokenizer\n\n\ndef test_tokenize():\n    tokenizer = _tokenizer."
  },
  {
    "path": "src/openpi/models/utils/fsq_tokenizer.py",
    "chars": 15928,
    "preview": "import math\nfrom typing import Any, Literal\n\nimport chex\nfrom einops import einops\nfrom flax import linen as nn\nfrom fla"
  },
  {
    "path": "src/openpi/models/vit.py",
    "chars": 10266,
    "preview": "# Copyright 2024 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this "
  },
  {
    "path": "src/openpi/models_pytorch/gemma_pytorch.py",
    "chars": 13509,
    "preview": "from typing import Literal\n\nimport pytest\nimport torch\nfrom torch import nn\nfrom transformers import GemmaForCausalLM\nfr"
  },
  {
    "path": "src/openpi/models_pytorch/pi0_pytorch.py",
    "chars": 19128,
    "preview": "import logging\nimport math\n\nimport torch\nfrom torch import Tensor\nfrom torch import nn\nimport torch.nn.functional as F  "
  },
  {
    "path": "src/openpi/models_pytorch/preprocessing_pytorch.py",
    "chars": 7264,
    "preview": "from collections.abc import Sequence\nimport logging\n\nimport torch\n\nfrom openpi.shared import image_tools\n\nlogger = loggi"
  },
  {
    "path": "src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py",
    "chars": 8670,
    "preview": "#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n#           This file was automatically generated from"
  },
  {
    "path": "src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py",
    "chars": 36104,
    "preview": "#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n#           This file was automatically generated from"
  },
  {
    "path": "src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py",
    "chars": 27688,
    "preview": "# coding=utf-8\n# Copyright 2024 the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, V"
  },
  {
    "path": "src/openpi/models_pytorch/transformers_replace/models/siglip/check.py",
    "chars": 133,
    "preview": "import transformers\n\ndef check_whether_transformers_replace_is_installed_correctly():\n    return transformers.__version_"
  },
  {
    "path": "src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py",
    "chars": 50231,
    "preview": "# coding=utf-8\n# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache L"
  },
  {
    "path": "src/openpi/policies/aloha_policy.py",
    "chars": 7543,
    "preview": "import dataclasses\nfrom typing import ClassVar\n\nimport einops\nimport numpy as np\n\nfrom openpi import transforms\n\n\ndef ma"
  },
  {
    "path": "src/openpi/policies/droid_policy.py",
    "chars": 3186,
    "preview": "import dataclasses\n\nimport einops\nimport numpy as np\n\nfrom openpi import transforms\nfrom openpi.models import model as _"
  },
  {
    "path": "src/openpi/policies/libero_policy.py",
    "chars": 4320,
    "preview": "import dataclasses\n\nimport einops\nimport numpy as np\n\nfrom openpi import transforms\nfrom openpi.models import model as _"
  },
  {
    "path": "src/openpi/policies/policy.py",
    "chars": 5344,
    "preview": "from collections.abc import Sequence\nimport logging\nimport pathlib\nimport time\nfrom typing import Any, TypeAlias\n\nimport"
  },
  {
    "path": "src/openpi/policies/policy_config.py",
    "chars": 4107,
    "preview": "import logging\nimport os\nimport pathlib\nfrom typing import Any\n\nimport jax.numpy as jnp\n\nimport openpi.models.model as _"
  },
  {
    "path": "src/openpi/policies/policy_test.py",
    "chars": 1127,
    "preview": "from openpi_client import action_chunk_broker\nimport pytest\n\nfrom openpi.policies import aloha_policy\nfrom openpi.polici"
  },
  {
    "path": "src/openpi/py.typed",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/openpi/serving/websocket_policy_server.py",
    "chars": 3051,
    "preview": "import asyncio\nimport http\nimport logging\nimport time\nimport traceback\n\nfrom openpi_client import base_policy as _base_p"
  },
  {
    "path": "src/openpi/shared/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/openpi/shared/array_typing.py",
    "chars": 3402,
    "preview": "import contextlib\nimport functools as ft\nimport inspect\nfrom typing import TypeAlias, TypeVar, cast\n\nimport beartype\nimp"
  },
  {
    "path": "src/openpi/shared/download.py",
    "chars": 8862,
    "preview": "import concurrent.futures\nimport datetime\nimport logging\nimport os\nimport pathlib\nimport re\nimport shutil\nimport stat\nim"
  },
  {
    "path": "src/openpi/shared/download_test.py",
    "chars": 1490,
    "preview": "import pathlib\n\nimport pytest\n\nimport openpi.shared.download as download\n\n\n@pytest.fixture(scope=\"session\", autouse=True"
  },
  {
    "path": "src/openpi/shared/image_tools.py",
    "chars": 4665,
    "preview": "import functools\n\nimport jax\nimport jax.numpy as jnp\nimport torch\nimport torch.nn.functional as F  # noqa: N812\n\nimport "
  },
  {
    "path": "src/openpi/shared/image_tools_test.py",
    "chars": 1418,
    "preview": "import jax.numpy as jnp\n\nfrom openpi.shared import image_tools\n\n\ndef test_resize_with_pad_shapes():\n    # Test case 1: R"
  },
  {
    "path": "src/openpi/shared/nnx_utils.py",
    "chars": 2844,
    "preview": "from collections.abc import Callable\nimport dataclasses\nimport functools\nimport inspect\nimport re\nfrom typing import Any"
  },
  {
    "path": "src/openpi/shared/normalize.py",
    "chars": 5529,
    "preview": "import json\nimport pathlib\n\nimport numpy as np\nimport numpydantic\nimport pydantic\n\n\n@pydantic.dataclasses.dataclass\nclas"
  },
  {
    "path": "src/openpi/shared/normalize_test.py",
    "chars": 1521,
    "preview": "import numpy as np\n\nimport openpi.shared.normalize as normalize\n\n\ndef test_normalize_update():\n    arr = np.arange(12).r"
  },
  {
    "path": "src/openpi/training/checkpoints.py",
    "chars": 5763,
    "preview": "from __future__ import annotations\n\nimport asyncio\nimport concurrent.futures as futures\nimport dataclasses\nimport loggin"
  },
  {
    "path": "src/openpi/training/config.py",
    "chars": 44594,
    "preview": "\"\"\"See _CONFIGS for the list of available configs.\"\"\"\n\nimport abc\nfrom collections.abc import Sequence\nimport dataclasse"
  },
  {
    "path": "src/openpi/training/data_loader.py",
    "chars": 19977,
    "preview": "from collections.abc import Iterator, Sequence\nimport logging\nimport multiprocessing\nimport os\nimport typing\nfrom typing"
  },
  {
    "path": "src/openpi/training/data_loader_test.py",
    "chars": 2487,
    "preview": "import dataclasses\n\nimport jax\n\nfrom openpi.models import pi0_config\nfrom openpi.training import config as _config\nfrom "
  },
  {
    "path": "src/openpi/training/droid_rlds_dataset.py",
    "chars": 10886,
    "preview": "\"\"\"\nRLDS-based data loader for DROID.\nWhile openpi typically uses LeRobot's data loader, it is not currently scalable en"
  },
  {
    "path": "src/openpi/training/misc/polaris_config.py",
    "chars": 9624,
    "preview": "\"\"\"PolaRiS baseline policy configs.\"\"\"\n\nfrom typing import TypeAlias\n\nimport openpi.models.model as _model\nimport openpi"
  },
  {
    "path": "src/openpi/training/misc/roboarena_config.py",
    "chars": 4804,
    "preview": "\"\"\"RoboArena baseline policy configs.\"\"\"\n\nfrom typing import TypeAlias\n\nimport openpi.models.model as _model\nimport open"
  },
  {
    "path": "src/openpi/training/optimizer.py",
    "chars": 3202,
    "preview": "import dataclasses\nfrom typing import Protocol, runtime_checkable\n\nimport jax.numpy as jnp\nimport optax\n\nimport openpi.s"
  },
  {
    "path": "src/openpi/training/sharding.py",
    "chars": 4099,
    "preview": "import contextlib\nimport logging\n\nimport jax\nimport numpy as np\n\nBATCH_AXIS = \"batch\"\nFSDP_AXIS = \"fsdp\"\n# In FSDP, we s"
  },
  {
    "path": "src/openpi/training/utils.py",
    "chars": 1217,
    "preview": "from collections.abc import Callable\nfrom typing import Any\n\nfrom flax import nnx\nfrom flax import struct\nimport jax\nimp"
  },
  {
    "path": "src/openpi/training/weight_loaders.py",
    "chars": 3738,
    "preview": "import dataclasses\nimport logging\nimport re\nfrom typing import Protocol, runtime_checkable\n\nimport flax.traverse_util\nim"
  },
  {
    "path": "src/openpi/transforms.py",
    "chars": 15752,
    "preview": "from collections.abc import Callable, Mapping, Sequence\nimport dataclasses\nimport re\nfrom typing import Protocol, TypeAl"
  },
  {
    "path": "src/openpi/transforms_test.py",
    "chars": 4111,
    "preview": "import numpy as np\nimport pytest\n\nimport openpi.models.tokenizer as _tokenizer\nimport openpi.transforms as _transforms\n\n"
  }
]

About this extraction

This page contains the full source code of the Physical-Intelligence/openpi GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 139 files (789.8 KB), approximately 196.6k tokens, and a symbol index with 837 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.