Repository: Physical-Intelligence/openpi Branch: main Commit: 54cbaee6ae0c Files: 139 Total size: 789.8 KB Directory structure: gitextract_6rhyljlr/ ├── .dockerignore ├── .github/ │ ├── CODEOWNERS │ └── workflows/ │ ├── pre-commit.yml │ └── test.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── .python-version ├── .vscode/ │ └── settings.json ├── CONTRIBUTING.md ├── LICENSE ├── LICENSE_GEMMA.txt ├── README.md ├── docs/ │ ├── docker.md │ ├── norm_stats.md │ └── remote_inference.md ├── examples/ │ ├── aloha_real/ │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── compose.yml │ │ ├── constants.py │ │ ├── convert_aloha_data_to_lerobot.py │ │ ├── env.py │ │ ├── main.py │ │ ├── real_env.py │ │ ├── requirements.in │ │ ├── requirements.txt │ │ ├── robot_utils.py │ │ └── video_display.py │ ├── aloha_sim/ │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── compose.yml │ │ ├── env.py │ │ ├── main.py │ │ ├── requirements.in │ │ ├── requirements.txt │ │ └── saver.py │ ├── convert_jax_model_to_pytorch.py │ ├── droid/ │ │ ├── README.md │ │ ├── README_train.md │ │ ├── compute_droid_nonidle_ranges.py │ │ ├── convert_droid_data_to_lerobot.py │ │ └── main.py │ ├── inference.ipynb │ ├── libero/ │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── compose.yml │ │ ├── convert_libero_data_to_lerobot.py │ │ ├── main.py │ │ ├── requirements.in │ │ └── requirements.txt │ ├── policy_records.ipynb │ ├── simple_client/ │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── compose.yml │ │ ├── main.py │ │ ├── requirements.in │ │ └── requirements.txt │ └── ur5/ │ └── README.md ├── packages/ │ └── openpi-client/ │ ├── pyproject.toml │ └── src/ │ └── openpi_client/ │ ├── __init__.py │ ├── action_chunk_broker.py │ ├── base_policy.py │ ├── image_tools.py │ ├── image_tools_test.py │ ├── msgpack_numpy.py │ ├── msgpack_numpy_test.py │ ├── runtime/ │ │ ├── agent.py │ │ ├── agents/ │ │ │ └── policy_agent.py │ │ ├── environment.py │ │ ├── runtime.py │ │ └── subscriber.py │ └── websocket_client_policy.py ├── 
pyproject.toml ├── scripts/ │ ├── __init__.py │ ├── compute_norm_stats.py │ ├── docker/ │ │ ├── compose.yml │ │ ├── install_docker_ubuntu22.sh │ │ ├── install_nvidia_container_toolkit.sh │ │ └── serve_policy.Dockerfile │ ├── serve_policy.py │ ├── train.py │ ├── train_pytorch.py │ └── train_test.py └── src/ └── openpi/ ├── __init__.py ├── conftest.py ├── models/ │ ├── __init__.py │ ├── gemma.py │ ├── gemma_fast.py │ ├── lora.py │ ├── lora_test.py │ ├── model.py │ ├── model_test.py │ ├── pi0.py │ ├── pi0_config.py │ ├── pi0_fast.py │ ├── pi0_test.py │ ├── siglip.py │ ├── tokenizer.py │ ├── tokenizer_test.py │ ├── utils/ │ │ └── fsq_tokenizer.py │ └── vit.py ├── models_pytorch/ │ ├── gemma_pytorch.py │ ├── pi0_pytorch.py │ ├── preprocessing_pytorch.py │ └── transformers_replace/ │ └── models/ │ ├── gemma/ │ │ ├── configuration_gemma.py │ │ └── modeling_gemma.py │ ├── paligemma/ │ │ └── modeling_paligemma.py │ └── siglip/ │ ├── check.py │ └── modeling_siglip.py ├── policies/ │ ├── aloha_policy.py │ ├── droid_policy.py │ ├── libero_policy.py │ ├── policy.py │ ├── policy_config.py │ └── policy_test.py ├── py.typed ├── serving/ │ └── websocket_policy_server.py ├── shared/ │ ├── __init__.py │ ├── array_typing.py │ ├── download.py │ ├── download_test.py │ ├── image_tools.py │ ├── image_tools_test.py │ ├── nnx_utils.py │ ├── normalize.py │ └── normalize_test.py ├── training/ │ ├── checkpoints.py │ ├── config.py │ ├── data_loader.py │ ├── data_loader_test.py │ ├── droid_rlds_dataset.py │ ├── misc/ │ │ ├── polaris_config.py │ │ └── roboarena_config.py │ ├── optimizer.py │ ├── sharding.py │ ├── utils.py │ └── weight_loaders.py ├── transforms.py └── transforms_test.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .dockerignore ================================================ .venv checkpoints data ================================================ FILE: 
.github/CODEOWNERS ================================================ # The CODEOWNERS file defines individuals or teams that are automatically requested for # review when someone opens a pull request that modifies certain code. When a draft pull # request is marked as ready for review, code owners are automatically notified. # # See: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners # # This is a comment. # Each line is a file pattern followed by one or more owners. # Global owners. * @jimmyt857 @Michael-Equi @kvablack src/openpi/models/ @kvablack src/openpi/training/ @kvablack scripts/ @jimmyt857 @kvablack ================================================ FILE: .github/workflows/pre-commit.yml ================================================ name: pre-commit on: push: branches: - main pull_request: branches: - "*" jobs: pre-commit: runs-on: ubuntu-latest env: GIT_LFS_SKIP_SMUDGE: true steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v3 - uses: pre-commit/action@v3.0.1 ================================================ FILE: .github/workflows/test.yml ================================================ name: Test on: pull_request: branches: - "*" jobs: run_tests: name: Run Tests runs-on: openpi-verylarge env: GIT_LFS_SKIP_SMUDGE: true steps: - uses: actions/checkout@v4 - name: Install FFmpeg dependencies run: | sudo apt-get update sudo apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev - name: Install uv uses: astral-sh/setup-uv@v5 - name: Set up Python run: uv python install - name: Install the project run: uv sync --all-extras --dev - name: Run tests run: uv run pytest --strict-markers -m "not manual" ================================================ FILE: .gitignore ================================================ # Data directories. 
assets/ checkpoints/ data/ wandb/ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
#pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/latest/usage/project/#working-with-version-control .pdm.toml .pdm-python .pdm-build/ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ ================================================ FILE: .gitmodules ================================================ [submodule "third_party/aloha"] path = third_party/aloha url = https://github.com/Physical-Intelligence/aloha.git [submodule "third_party/libero"] path = third_party/libero url = https://github.com/Lifelong-Robot-Learning/LIBERO.git ================================================ FILE: .pre-commit-config.yaml ================================================ exclude: third_party/ repos: - repo: https://github.com/astral-sh/uv-pre-commit # uv version. rev: 0.5.14 hooks: - id: uv-lock - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. rev: v0.8.6 hooks: # Run the linter. 
- id: ruff args: [--fix] - id: ruff-format ================================================ FILE: .python-version ================================================ 3.11 ================================================ FILE: .vscode/settings.json ================================================ { "[python]": { "editor.defaultFormatter": "charliermarsh.ruff", "editor.formatOnSave": true, }, "python.testing.pytestArgs": [ "src" ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true } ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to openpi We welcome contributions, improvements, and modifications. Everyone is welcome to use openpi in accordance with the [license](LICENSE). Contributors are also welcome to submit bug reports, feature requests, and pull requests. We can't promise to approve every pull request, and we are a small team with limited bandwidth to review all requests, but we'll give it our best effort. Specifics are described below. ## Issues and feature requests You are welcome to use the GitHub [discussion](https://github.com/Physical-Intelligence/openpi/discussions) feature if you would like to discuss something that is not directly reporting an issue or making a feature request. This is suitable for questions about how to use some aspect of openpi, or other topics. If you found a bug or other issue, please first check that the issue was not already reported (use the search bar on GitHub under Issues). 
If the issue has not yet been reported, please include this information when filing a GitHub issue: - Your OS type and version and the version of Python you are using - Code that allows us to reproduce your bug, including all dependencies - Traceback of any exception - Any other information that would help us, such as a screenshot In order for us to address any issue, we must be able to reproduce it, so if you encountered the issue after making modifications to openpi, please reproduce the issue without any other modifications and provide a code snippet that allows us to quickly reproduce the problem on `main`. If you would like to submit a feature request, please check that the feature request does not already exist, and please provide the following information: - The motivation for the feature - A description of the problem you are trying to solve or your use case - Enough information for us to understand the nature of the request - Some information about how you intend to use it (this might help us in understanding the motivation!) We can't promise to support every feature request, but it is helpful to us to know the use cases that you are interested in! ## Submitting a pull request If you implemented support for a new robot or environment, or some other new feature, we welcome pull requests (PRs) to openpi. We encourage you to create a [feature request](https://github.com/Physical-Intelligence/openpi/issues) or make a post on the [discussion](https://github.com/Physical-Intelligence/openpi/discussions) board before starting to work on your PR, if you would like to get a sense for whether we are likely to approve your PR if it is submitted. 
Since we are a small team with limited ability to provide maintenance and support, we may not accept all PRs (e.g., if we believe it would make the code harder to maintain, or if reviewing the PR is out of scope for us), so contacting us in advance is a good way to get a sense for whether your PR is likely to get approved for merging into openpi directly. But even if it isn't, you are of course more than welcome to maintain your own fork with whatever modifications you would like. When creating PRs, we recommend that every contribution consider the following: - Make sure that your PR has a clear title and description - Run `pre-commit` (install using `pre-commit install` first), and run `ruff check .` and `ruff format .` - Make sure your PR passes all tests ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: LICENSE_GEMMA.txt ================================================ Gemma Terms of Use Last modified: February 21, 2024 By using, reproducing, modifying, distributing, performing or displaying any portion or element of Gemma, Model Derivatives including via any Hosted Service, (each as defined below) (collectively, the "Gemma Services") or otherwise accepting the terms of this Agreement, you agree to be bound by this Agreement. Section 1: DEFINITIONS 1.1 Definitions (a) "Agreement" or "Gemma Terms of Use" means these terms and conditions that govern the use, reproduction, Distribution or modification of the Gemma Services and any terms and conditions incorporated by reference. (b) "Distribution" or "Distribute" means any transmission, publication, or other sharing of Gemma or Model Derivatives to a third party, including by providing or making Gemma or its functionality available as a hosted service via API, web access, or any other electronic or remote means ("Hosted Service"). (c) "Gemma" means the set of machine learning language models, trained model weights and parameters identified at ai.google.dev/gemma, regardless of the source that you obtained it from. (d) "Google" means Google LLC. 
(e) "Model Derivatives" means all (i) modifications to Gemma, (ii) works based on Gemma, or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Gemma, to that model in order to cause that model to perform similarly to Gemma, including distillation methods that use intermediate data representations or methods based on the generation of synthetic data Outputs by Gemma for training that model. For clarity, Outputs are not deemed Model Derivatives. (f) "Output" means the information content output of Gemma or a Model Derivative that results from operating or otherwise using Gemma or the Model Derivative, including via a Hosted Service. 1.2 As used in this Agreement, "including" means "including without limitation". Section 2: ELIGIBILITY AND USAGE 2.1 Eligibility You represent and warrant that you have the legal capacity to enter into this Agreement (including being of sufficient age of consent). If you are accessing or using any of the Gemma Services for or on behalf of a legal entity, (a) you are entering into this Agreement on behalf of yourself and that legal entity, (b) you represent and warrant that you have the authority to act on behalf of and bind that entity to this Agreement and (c) references to "you" or "your" in the remainder of this Agreement refers to both you (as an individual) and that entity. 2.2 Use You may use, reproduce, modify, Distribute, perform or display any of the Gemma Services only in accordance with the terms of this Agreement, and must not violate (or encourage or permit anyone else to violate) any term of this Agreement. 
Section 3: DISTRIBUTION AND RESTRICTIONS 3.1 Distribution and Redistribution You may reproduce or Distribute copies of Gemma or Model Derivatives if you meet all of the following conditions: You must include the use restrictions referenced in Section 3.2 as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Gemma or Model Derivatives and you must provide notice to subsequent users you Distribute to that Gemma or Model Derivatives are subject to the use restrictions in Section 3.2. You must provide all third party recipients of Gemma or Model Derivatives a copy of this Agreement. You must cause any modified files to carry prominent notices stating that you modified the files. All Distributions (other than through a Hosted Service) must be accompanied by a "Notice" text file that contains the following notice: "Gemma is provided under and subject to the Gemma Terms of Use found at ai.google.dev/gemma/terms". You may add your own intellectual property statement to your modifications and, except as set forth in this Section, may provide additional or different terms and conditions for use, reproduction, or Distribution of your modifications, or for any such Model Derivatives as a whole, provided your use, reproduction, modification, Distribution, performance, and display of Gemma otherwise complies with the terms and conditions of this Agreement. Any additional or different terms and conditions you impose must not conflict with the terms of this Agreement. 3.2 Use Restrictions You must not use any of the Gemma Services: for the restricted uses set forth in the Gemma Prohibited Use Policy at ai.google.dev/gemma/prohibited_use_policy ("Prohibited Use Policy"), which is hereby incorporated by reference into this Agreement; or in violation of applicable laws and regulations. 
To the maximum extent permitted by law, Google reserves the right to restrict (remotely or otherwise) usage of any of the Gemma Services that Google reasonably believes are in violation of this Agreement. 3.3 Generated Output Google claims no rights in Outputs you generate using Gemma. You and your users are solely responsible for Outputs and their subsequent uses. Section 4: ADDITIONAL PROVISIONS 4.1 Updates Google may update Gemma from time to time, and you must make reasonable efforts to use the latest version of Gemma. 4.2 Trademarks Nothing in this Agreement grants you any rights to use Google's trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between you and Google. Google reserves any rights not expressly granted herein. 4.3 DISCLAIMER OF WARRANTY UNLESS REQUIRED BY APPLICABLE LAW, THE GEMMA SERVICES, AND OUTPUTS, ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR OR DISTRIBUTING ANY OF THE GEMMA SERVICES OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR USE OR DISTRIBUTION OF ANY OF THE GEMMA SERVICES OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT. 
4.4 LIMITATION OF LIABILITY TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), PRODUCT LIABILITY, CONTRACT, OR OTHERWISE, UNLESS REQUIRED BY APPLICABLE LAW, SHALL GOOGLE OR ITS AFFILIATES BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL, OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO, ANY OF THE GEMMA SERVICES OR OUTPUTS EVEN IF GOOGLE OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 4.5 Term, Termination, and Survival The term of this Agreement will commence upon your acceptance of this Agreement (including acceptance by your use, modification, or Distribution, reproduction, performance or display of any portion or element of the Gemma Services) and will continue in full force and effect until terminated in accordance with the terms of this Agreement. Google may terminate this Agreement if you are in breach of any term of this Agreement. Upon termination of this Agreement, you must delete and cease use and Distribution of all copies of Gemma and Model Derivatives in your possession or control. Sections 1, 2.1, 3.3, 4.2 to 4.9 shall survive the termination of this Agreement. 4.6 Governing Law and Jurisdiction This Agreement will be governed by the laws of the State of California without regard to choice of law principles. The UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The state and federal courts of Santa Clara County, California shall have exclusive jurisdiction of any dispute arising out of this Agreement. 4.7 Severability If any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. 
4.8 Entire Agreement This Agreement states all the terms agreed between the parties and supersedes all other agreements between the parties as of the date of acceptance relating to its subject matter. 4.9 No Waiver Google will not be treated as having waived any rights by not exercising (or delaying the exercise of) any rights under this Agreement. ================================================ FILE: README.md ================================================ # openpi openpi holds open-source models and packages for robotics, published by the [Physical Intelligence team](https://www.physicalintelligence.company/). Currently, this repo contains three types of models: - the [π₀ model](https://www.physicalintelligence.company/blog/pi0), a flow-based vision-language-action model (VLA). - the [π₀-FAST model](https://www.physicalintelligence.company/research/fast), an autoregressive VLA, based on the FAST action tokenizer. - the [π₀.₅ model](https://www.physicalintelligence.company/blog/pi05), an upgraded version of π₀ with better open-world generalization trained with [knowledge insulation](https://www.physicalintelligence.company/research/knowledge_insulation). Note that, in this repository, we currently only support the flow matching head for both $\pi_{0.5}$ training and inference. For all models, we provide _base model_ checkpoints, pre-trained on 10k+ hours of robot data, and examples for using them out of the box or fine-tuning them to your own datasets. This is an experiment: $\pi_0$ was developed for our own robots, which differ from the widely used platforms such as [ALOHA](https://tonyzhaozh.github.io/aloha/) and [DROID](https://droid-dataset.github.io/), and though we are optimistic that researchers and practitioners will be able to run creative new experiments adapting $\pi_0$ to their own platforms, we do not expect every such attempt to be successful. All this is to say: $\pi_0$ may or may not work for you, but you are welcome to try it and see! 
## Updates - [Sept 2025] We released PyTorch support in openpi. - [Sept 2025] We released pi05, an upgraded version of pi0 with better open-world generalization. - [Sept 2025]: We have added an [improved idle filter](examples/droid/README_train.md#data-filtering) for DROID training. - [Jun 2025]: We have added [instructions](examples/droid/README_train.md) for using `openpi` to train VLAs on the full [DROID dataset](https://droid-dataset.github.io/). This is an approximate open-source implementation of the training pipeline used to train pi0-FAST-DROID. ## Requirements To run the models in this repository, you will need an NVIDIA GPU with at least the following specifications. These estimations assume a single GPU, but you can also use multiple GPUs with model parallelism to reduce per-GPU memory requirements by configuring `fsdp_devices` in the training config. Please also note that the current training script does not yet support multi-node training. | Mode | Memory Required | Example GPU | | ------------------ | --------------- | ------------------ | | Inference | > 8 GB | RTX 4090 | | Fine-Tuning (LoRA) | > 22.5 GB | RTX 4090 | | Fine-Tuning (Full) | > 70 GB | A100 (80GB) / H100 | The repo has been tested with Ubuntu 22.04, we do not currently support other operating systems. ## Installation When cloning this repo, make sure to update submodules: ```bash git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git # Or if you already cloned the repo: git submodule update --init --recursive ``` We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up. Once uv is installed, run the following to set up the environment: ```bash GIT_LFS_SKIP_SMUDGE=1 uv sync GIT_LFS_SKIP_SMUDGE=1 uv pip install -e . ``` NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency. 
**Docker**: As an alternative to uv installation, we provide instructions for installing openpi using Docker. If you encounter issues with your system setup, consider using Docker to simplify installation. See [Docker Setup](docs/docker.md) for more details. ## Model Checkpoints ### Base Models We provide multiple base VLA model checkpoints. These checkpoints have been pre-trained on 10k+ hours of robot data, and can be used for fine-tuning. | Model | Use Case | Description | Checkpoint Path | | ------------ | ----------- | ----------------------------------------------------------------------------------------------------------- | ---------------------------------------------- | | $\pi_0$ | Fine-Tuning | Base [π₀ model](https://www.physicalintelligence.company/blog/pi0) for fine-tuning | `gs://openpi-assets/checkpoints/pi0_base` | | $\pi_0$-FAST | Fine-Tuning | Base autoregressive [π₀-FAST model](https://www.physicalintelligence.company/research/fast) for fine-tuning | `gs://openpi-assets/checkpoints/pi0_fast_base` | | $\pi_{0.5}$ | Fine-Tuning | Base [π₀.₅ model](https://www.physicalintelligence.company/blog/pi05) for fine-tuning | `gs://openpi-assets/checkpoints/pi05_base` | ### Fine-Tuned Models We also provide "expert" checkpoints for various robot platforms and tasks. These models are fine-tuned from the base models above and intended to run directly on the target robot. These may or may not work on your particular robot. Since these checkpoints were fine-tuned on relatively small datasets collected with more widely available robots, such as ALOHA and the DROID Franka setup, they might not generalize to your particular setup, though we found some of these, especially the DROID checkpoint, to generalize quite broadly in practice. 
| Model | Use Case | Description | Checkpoint Path | | ------------------------ | ----------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------- | | $\pi_0$-FAST-DROID | Inference | $\pi_0$-FAST model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/): can perform a wide range of simple table-top manipulation tasks 0-shot in new scenes on the DROID robot platform | `gs://openpi-assets/checkpoints/pi0_fast_droid` | | $\pi_0$-DROID | Fine-Tuning | $\pi_0$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/): faster inference than $\pi_0$-FAST-DROID, but may not follow language commands as well | `gs://openpi-assets/checkpoints/pi0_droid` | | $\pi_0$-ALOHA-towel | Inference | $\pi_0$ model fine-tuned on internal [ALOHA](https://tonyzhaozh.github.io/aloha/) data: can fold diverse towels 0-shot on ALOHA robot platforms | `gs://openpi-assets/checkpoints/pi0_aloha_towel` | | $\pi_0$-ALOHA-tupperware | Inference | $\pi_0$ model fine-tuned on internal [ALOHA](https://tonyzhaozh.github.io/aloha/) data: can unpack food from a tupperware container | `gs://openpi-assets/checkpoints/pi0_aloha_tupperware` | | $\pi_0$-ALOHA-pen-uncap | Inference | $\pi_0$ model fine-tuned on public [ALOHA](https://dit-policy.github.io/) data: can uncap a pen | `gs://openpi-assets/checkpoints/pi0_aloha_pen_uncap` | | $\pi_{0.5}$-LIBERO | Inference | $\pi_{0.5}$ model fine-tuned for the [LIBERO](https://libero-project.github.io/datasets) benchmark: gets state-of-the-art performance (see [LIBERO README](examples/libero/README.md)) | `gs://openpi-assets/checkpoints/pi05_libero` | | $\pi_{0.5}$-DROID | Inference / Fine-Tuning | $\pi_{0.5}$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/) with [knowledge 
insulation](https://www.physicalintelligence.company/research/knowledge_insulation): fast inference and good language-following | `gs://openpi-assets/checkpoints/pi05_droid` | By default, checkpoints are automatically downloaded from `gs://openpi-assets` and are cached in `~/.cache/openpi` when needed. You can override the download path by setting the `OPENPI_DATA_HOME` environment variable. ## Running Inference for a Pre-Trained Model Our pre-trained model checkpoints can be run with a few lines of code (here our $\pi_{0.5}$-DROID model): ```python from openpi.training import config as _config from openpi.policies import policy_config from openpi.shared import download config = _config.get_config("pi05_droid") checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi05_droid") # Create a trained policy. policy = policy_config.create_trained_policy(config, checkpoint_dir) # Run inference on a dummy example. example = { "observation/exterior_image_1_left": ..., "observation/wrist_image_left": ..., ... "prompt": "pick up the fork" } action_chunk = policy.infer(example)["actions"] ``` You can also test this out in the [example notebook](examples/inference.ipynb). We provide detailed step-by-step examples for running inference of our pre-trained checkpoints on [DROID](examples/droid/README.md) and [ALOHA](examples/aloha_real/README.md) robots. **Remote Inference**: We provide [examples and code](docs/remote_inference.md) for running inference of our models **remotely**: the model can run on a different server and stream actions to the robot via a websocket connection. This makes it easy to use more powerful GPUs off-robot and keep robot and policy environments separate. **Test inference without a robot**: We provide a [script](examples/simple_client/README.md) for testing inference without a robot. This script will generate a random observation and run inference with the model. See [here](examples/simple_client/README.md) for more details.
## Fine-Tuning Base Models on Your Own Data We will fine-tune the $\pi_{0.5}$ model on the [LIBERO dataset](https://libero-project.github.io/datasets) as a running example for how to fine-tune a base model on your own data. We will explain three steps: 1. Convert your data to a LeRobot dataset (which we use for training) 2. Defining training configs and running training 3. Spinning up a policy server and running inference ### 1. Convert your data to a LeRobot dataset We provide a minimal example script for converting LIBERO data to a LeRobot dataset in [`examples/libero/convert_libero_data_to_lerobot.py`](examples/libero/convert_libero_data_to_lerobot.py). You can easily modify it to convert your own data! You can download the raw LIBERO dataset from [here](https://huggingface.co/datasets/openvla/modified_libero_rlds), and run the script with: ```bash uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/libero/data ``` **Note:** If you just want to fine-tune on LIBERO, you can skip this step, because our LIBERO fine-tuning configs point to a pre-converted LIBERO dataset. This step is merely an example that you can adapt to your own data. ### 2. Defining training configs and running training To fine-tune a base model on your own data, you need to define configs for data processing and training. We provide example configs with detailed comments for LIBERO below, which you can modify for your own dataset: - [`LiberoInputs` and `LiberoOutputs`](src/openpi/policies/libero_policy.py): Defines the data mapping from the LIBERO environment to the model and vice versa. Will be used for both, training and inference. - [`LeRobotLiberoDataConfig`](src/openpi/training/config.py): Defines how to process raw LIBERO data from LeRobot dataset for training. - [`TrainConfig`](src/openpi/training/config.py): Defines fine-tuning hyperparameters, data config, and weight loader. 
We provide example fine-tuning configs for [π₀](src/openpi/training/config.py), [π₀-FAST](src/openpi/training/config.py), and [π₀.₅](src/openpi/training/config.py) on LIBERO data. Before we can run training, we need to compute the normalization statistics for the training data. Run the script below with the name of your training config: ```bash uv run scripts/compute_norm_stats.py --config-name pi05_libero ``` Now we can kick off training with the following command (the `--overwrite` flag is used to overwrite existing checkpoints if you rerun fine-tuning with the same config): ```bash XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_libero --exp-name=my_experiment --overwrite ``` The command will log training progress to the console and save checkpoints to the `checkpoints` directory. You can also monitor training progress on the Weights & Biases dashboard. For maximally using the GPU memory, set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` before running training -- this enables JAX to use up to 90% of the GPU memory (vs. the default of 75%). **Note:** We provide functionality for *reloading* normalization statistics for state / action normalization from pre-training. This can be beneficial if you are fine-tuning to a new task on a robot that was part of our pre-training mixture. For more details on how to reload normalization statistics, see the [norm_stats.md](docs/norm_stats.md) file. ### 3. Spinning up a policy server and running inference Once training is complete, we can run inference by spinning up a policy server and then querying it from a LIBERO evaluation script. Launching a model server is easy (we use the checkpoint for iteration 20,000 for this example, modify as needed): ```bash uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_libero --policy.dir=checkpoints/pi05_libero/my_experiment/20000 ``` This will spin up a server that listens on port 8000 and waits for observations to be sent to it. 
We can then run an evaluation script (or robot runtime) that queries the server. For running the LIBERO eval in particular, we provide (and recommend using) a Dockerized workflow that handles both the policy server and the evaluation script together. See the [LIBERO README](examples/libero/README.md) for more details. If you want to embed a policy server call in your own robot runtime, we have a minimal example of how to do so in the [remote inference docs](docs/remote_inference.md). ### More Examples We provide more examples for how to fine-tune and run inference with our models on the ALOHA platform in the following READMEs: - [ALOHA Simulator](examples/aloha_sim) - [ALOHA Real](examples/aloha_real) - [UR5](examples/ur5) ## PyTorch Support openpi now provides PyTorch implementations of π₀ and π₀.₅ models alongside the original JAX versions! The PyTorch implementation has been validated on the LIBERO benchmark (both inference and finetuning). A few features are currently not supported (this may change in the future): - The π₀-FAST model - Mixed precision training - FSDP (fully-sharded data parallelism) training - LoRA (low-rank adaptation) training - EMA (exponential moving average) weights during training ### Setup 1. Make sure that you have the latest version of all dependencies installed: `uv sync` 2. Double check that you have transformers 4.53.2 installed: `uv pip show transformers` 3. Apply the transformers library patches: ```bash cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/ ``` This overwrites several files in the transformers library with necessary model changes: 1) supporting AdaRMS, 2) correctly controlling the precision of activations, and 3) allowing the KV cache to be used without being updated. 
**WARNING**: With the default uv link mode (hardlink), this will permanently affect the transformers library in your uv cache, meaning the changes will survive reinstallations of transformers and could even propagate to other projects that use transformers. To fully undo this operation, you must run `uv cache clean transformers`. ### Converting JAX Models to PyTorch To convert a JAX model checkpoint to PyTorch format: ```bash uv run examples/convert_jax_model_to_pytorch.py \ --checkpoint_dir /path/to/jax/checkpoint \ --config_name <config_name> \ --output_path /path/to/converted/pytorch/checkpoint ``` ### Running Inference with PyTorch The PyTorch implementation uses the same API as the JAX version - you only need to change the checkpoint path to point to the converted PyTorch model: ```python from openpi.training import config as _config from openpi.policies import policy_config from openpi.shared import download config = _config.get_config("pi05_droid") checkpoint_dir = "/path/to/converted/pytorch/checkpoint" # Create a trained policy (automatically detects PyTorch format) policy = policy_config.create_trained_policy(config, checkpoint_dir) # Run inference (same API as JAX) action_chunk = policy.infer(example)["actions"] ``` ### Policy Server with PyTorch The policy server works identically with PyTorch models - just point to the converted checkpoint directory: ```bash uv run scripts/serve_policy.py policy:checkpoint \ --policy.config=pi05_droid \ --policy.dir=/path/to/converted/pytorch/checkpoint ``` ### Finetuning with PyTorch To finetune a model in PyTorch: 1. Convert the JAX base model to PyTorch format: ```bash uv run examples/convert_jax_model_to_pytorch.py \ --config_name <config_name> \ --checkpoint_dir /path/to/jax/base/model \ --output_path /path/to/pytorch/base/model ``` 2. Specify the converted PyTorch model path in your config using `pytorch_weight_path` 3.
Launch training using one of these modes: ```bash # Single GPU training: uv run scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval> # Example: uv run scripts/train_pytorch.py debug --exp_name pytorch_test uv run scripts/train_pytorch.py debug --exp_name pytorch_test --resume # Resume from latest checkpoint # Multi-GPU training (single node): uv run torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name> # Example: uv run torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test uv run torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume # Multi-Node Training: uv run torchrun \ --nnodes=<num_nodes> \ --nproc_per_node=<gpus_per_node> \ --node_rank=<rank_of_node> \ --master_addr=<master_ip> \ --master_port=<port> \ scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval> ``` ### Precision Settings JAX and PyTorch implementations handle precision as follows: **JAX:** 1. Inference: most weights and computations in bfloat16, with a few computations in float32 for stability 2. Training: defaults to mixed precision: weights and gradients in float32, (most) activations and computations in bfloat16. You can change to full float32 training by setting `dtype` to float32 in the config. **PyTorch:** 1. Inference: matches JAX -- most weights and computations in bfloat16, with a few weights converted to float32 for stability 2. Training: supports either full bfloat16 (default) or full float32. You can change it by setting `pytorch_training_precision` in the config. bfloat16 uses less memory but exhibits higher losses compared to float32. Mixed precision is not yet supported. With torch.compile, inference speed is comparable between JAX and PyTorch. ## Troubleshooting We will collect common issues and their solutions here. If you encounter an issue, please check here first.
If you can't find a solution, please file an issue on the repo (see [here](CONTRIBUTING.md) for guidelines). | Issue | Resolution | | ----------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | `uv sync` fails with dependency conflicts | Try removing the virtual environment directory (`rm -rf .venv`) and running `uv sync` again. If issues persist, check that you have the latest version of `uv` installed (`uv self update`). | | Training runs out of GPU memory | Make sure you set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` (or higher) before running training to allow JAX to use more GPU memory. You can also use `--fsdp-devices <n>` where `<n>` is your number of GPUs, to enable [fully-sharded data parallelism](https://engineering.fb.com/2021/07/15/open-source/fsdp/), which reduces memory usage in exchange for slower training (the amount of slowdown depends on your particular setup). If you are still running out of memory, you may want to consider disabling EMA. | | Policy server connection errors | Check that the server is running and listening on the expected port. Verify network connectivity and firewall settings between client and server. | | Missing norm stats error when training | Run `scripts/compute_norm_stats.py` with your config name before starting training. | | Dataset download fails | Check your internet connection. For HuggingFace datasets, ensure you're logged in (`huggingface-cli login`). | | CUDA/GPU errors | Verify NVIDIA drivers are installed correctly. For Docker, ensure nvidia-container-toolkit is installed. Check GPU compatibility. You do NOT need CUDA libraries installed at a system level --- they will be installed via uv. You may even want to try *uninstalling* system CUDA libraries if you run into CUDA issues, since system libraries can sometimes cause conflicts.
| | Import errors when running examples | Make sure you've installed all dependencies with `uv sync`. Some examples may have additional requirements listed in their READMEs. | | Action dimensions mismatch | Verify your data processing transforms match the expected input/output dimensions of your robot. Check the action space definitions in your policy classes. | | Diverging training loss | Check the `q01`, `q99`, and `std` values in `norm_stats.json` for your dataset. Certain dimensions that are rarely used can end up with very small `q01`, `q99`, or `std` values, leading to huge states and actions after normalization. You can manually adjust the norm stats as a workaround. | ================================================ FILE: docs/docker.md ================================================ ### Docker Setup All of the examples in this repo provide instructions for being run normally, and also using Docker. Although not required, the Docker option is recommended as this will simplify software installation, produce a more stable environment, and also allow you to avoid installing ROS and cluttering your machine, for examples which depend on ROS. - Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/). - Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/). - To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). - The version of docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`. - Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). 
Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`. If starting from scratch and your host machine is Ubuntu 22.04, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`. Build the Docker image and start the container with the following command: ```bash docker compose -f scripts/docker/compose.yml up --build ``` To build and run the Docker image for a specific example, use the following command: ```bash docker compose -f examples/<example_name>/compose.yml up --build ``` where `<example_name>` is the name of the example you want to run. During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached. ================================================ FILE: docs/norm_stats.md ================================================ # Normalization statistics Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint. ## Reloading normalization statistics When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model. **If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance.
You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, like below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint: ```python TrainConfig( ... data=LeRobotAlohaDataConfig( ... assets=AssetsConfig( assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets", asset_id="trossen", ), ), ) ``` For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py). **Note:** To successfully reload normalization statistics, it's important that your robot + dataset are following the action space definitions used in pre-training. We provide a detailed description of our action space definitions below. **Note #2:** Whether reloading normalization statistics is beneficial depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. We recommend to always try both, reloading and training with a fresh set of statistics computed on your new dataset (see [main README](../README.md) for instructions on how to compute new statistics), and pick the one that works better for your task. ## Provided Pre-training Normalization Statistics Below is a list of all the pre-training normalization statistics we provide. We provide them for both, the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_base/assets` and for `pi0_fast_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_fast_base/assets`. 
| Robot | Description | Asset ID | |-------|-------------|----------| | ALOHA | 6-DoF dual arm robot with parallel grippers | trossen | | Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile | | Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid | | Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka | | UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e | | UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual | | ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx | | ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile | | Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile | ## Pi0 Model Action Space Definitions Out of the box, both the `pi0_base` and `pi0_fast_base` use the following action space definitions (left and right are defined looking from behind the robot towards the workspace): ``` "dim_0:dim_5": "left arm joint angles", "dim_6": "left arm gripper position", "dim_7:dim_12": "right arm joint angles (for bi-manual only)", "dim_13": "right arm gripper position (for bi-manual only)", # For mobile robots: "dim_14:dim_15": "x-y base velocity (for mobile robots only)", ``` The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state. For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action. General info for Pi robots: - Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details). 
- Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed. - Control frequencies are 20 Hz for UR5e and Franka, and 50 Hz for ARX and Trossen (ALOHA) arms. For DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions and gripper actions in the 8th dimension + a control frequency of 15 Hz. ================================================ FILE: docs/remote_inference.md ================================================ # Running openpi models remotely We provide utilities for running openpi models remotely. This is useful for running inference on more powerful GPUs off-robot, and also helps keep the robot and policy environments separate (and e.g. avoid dependency hell with robot software). ## Starting a remote policy server To start a remote policy server, you can simply run the following command: ```bash uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO] ``` The `env` argument specifies which $\pi_0$ checkpoint should be loaded. Under the hood, this script will execute a command like the following, which you can use to start a policy server, e.g. for checkpoints you trained yourself (here an example for the DROID environment): ```bash uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid ``` This will start a policy server that will serve the policy specified by the `config` and `dir` arguments. The policy will be served on the specified port (default: 8000). ## Querying the remote policy server from your robot code We provide a client utility with minimal dependencies that you can easily embed into any robot codebase. First, install the `openpi-client` package in your robot environment: ```bash cd $OPENPI_ROOT/packages/openpi-client pip install -e . ``` Then, you can use the client to query the remote policy server from your robot code.
Here's an example of how to do this: ```python from openpi_client import image_tools from openpi_client import websocket_client_policy # Outside of episode loop, initialize the policy client. # Point to the host and port of the policy server (localhost and 8000 are the defaults). client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000) for step in range(num_steps): # Inside the episode loop, construct the observation. # Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format. # We provide utilities for resizing images + uint8 conversion so you match the training routines. # The typical resize_size for pre-trained pi0 models is 224. # Note that the proprioceptive `state` can be passed unnormalized, normalization will be handled on the server side. observation = { "observation/image": image_tools.convert_to_uint8( image_tools.resize_with_pad(img, 224, 224) ), "observation/wrist_image": image_tools.convert_to_uint8( image_tools.resize_with_pad(wrist_img, 224, 224) ), "observation/state": state, "prompt": task_instruction, } # Call the policy server with the current observation. # This returns an action chunk of shape (action_horizon, action_dim). # Note that you typically only need to call the policy every N steps and execute steps # from the predicted action chunk open-loop in the remaining steps. action_chunk = client.infer(observation)["actions"] # Execute the actions in the environment. ... ``` Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. 
We have concrete examples of how to construct this dictionary for different environments in the [simple client example](../examples/simple_client/main.py). ================================================ FILE: examples/aloha_real/Dockerfile ================================================ # Dockerfile for the Aloha real environment. # Build the container: # docker build . -t aloha_real -f examples/aloha_real/Dockerfile # Run the container: # docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash FROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc SHELL ["/bin/bash", "-c"] ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && \ apt-get install -y --no-install-recommends \ cmake \ curl \ libffi-dev \ python3-rosdep \ python3-rosinstall \ python3-rosinstall-generator \ whiptail \ git \ wget \ openssh-client \ ros-noetic-cv-bridge \ ros-noetic-usb-cam \ ros-noetic-realsense2-camera \ keyboard-configuration WORKDIR /root RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh RUN chmod +x xsarm_amd64_install.sh RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n COPY ./third_party/aloha /root/interbotix_ws/src/aloha RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make # Install python 3.10 because this ROS image comes with 3.8 RUN mkdir /python && \ cd /python && \ wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \ tar -zxvf Python-3.10.14.tgz && \ cd Python-3.10.14 && \ ls -lhR && \ ./configure --enable-optimizations && \ make install && \ echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \ echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \ cd ~ && rm -rf /python && \ rm -rf /var/lib/apt/lists/* COPY 
--from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv ENV UV_HTTP_TIMEOUT=120 ENV UV_LINK_MODE=copy COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha WORKDIR /app # Create an entrypoint script to run the setup commands, followed by the command passed in. RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh #!/bin/bash source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@" EOF RUN chmod +x /usr/local/bin/entrypoint.sh ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] CMD ["python3", "/app/examples/aloha_real/main.py"] ================================================ FILE: examples/aloha_real/README.md ================================================ # Run Aloha (Real Robot) This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below. ## Prerequisites This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras. 1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo. 1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras. 
## With Docker ```bash export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'" docker compose -f examples/aloha_real/compose.yml up --build ``` ## Without Docker Terminal window 1: ```bash # Create virtual environment uv venv --python 3.10 examples/aloha_real/.venv source examples/aloha_real/.venv/bin/activate uv pip sync examples/aloha_real/requirements.txt uv pip install -e packages/openpi-client # Run the robot python -m examples.aloha_real.main ``` Terminal window 2: ```bash roslaunch aloha ros_nodes.launch ``` Terminal window 3: ```bash uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster' ``` ## **ALOHA Checkpoint Guide** The `pi0_base` model can be used in zero shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA. While we’ve found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we’ve seen it work well on, and well-represented initial state distributions. Running these policies in zero shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by finetuning with data from the target robot. --- ### **Toast Task** This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate. - **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_base` - **Prompt**: "take the toast out of the toaster" - **Objects needed**: Two pieces of toast, a plate, and a standard toaster. 
- **Object Distribution**: - Works on both real toast and rubber fake toast - Compatible with standard 2-slice toasters - Works with plates of varying colors ### **Scene Setup Guidelines** Screenshot 2025-01-31 at 10 06 02 PM - The toaster should be positioned in the top-left quadrant of the workspace. - Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top. - The plate should be placed roughly in the lower-center of the workspace. - Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain). ### **Towel Task** This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths. - **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_towel` - **Prompt**: "fold the towel" - **Object Distribution**: - Works on towels of varying solid colors - Performance is worse on heavily textured or striped towels ### **Scene Setup Guidelines** Screenshot 2025-01-31 at 10 01 15 PM - The towel should be flattened and roughly centered on the table. - Choose a towel that does not blend in with the table surface. ### **Tupperware Task** This task involves opening a tupperware filled with food and pouring the contents onto a plate. - **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_tupperware` - **Prompt**: "open the tupperware and put the food on the plate" - **Objects needed**: Tupperware, food (or food-like items), and a plate. - **Object Distribution**: - Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken). - Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below). - The policy has seen plates of varying solid colors. 
### **Scene Setup Guidelines** Screenshot 2025-01-31 at 10 02 27 PM - Best performance observed when both the tupperware and plate are roughly centered in the workspace. - Positioning: - Tupperware should be on the left. - Plate should be on the right or bottom. - The tupperware flap should point toward the plate. ## Training on your own Aloha dataset 1. Convert the dataset to the LeRobot dataset v2.0 format. We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse). 2. Define a training config that uses the custom dataset. We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config. IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoint’s asset directory within the AssetsConfig. ================================================ FILE: examples/aloha_real/compose.yml ================================================ # Run with: # docker compose -f examples/aloha_real/compose.yml up --build services: runtime: image: aloha_real depends_on: - aloha_ros_nodes - ros_master - openpi_server build: context: ../.. 
dockerfile: examples/aloha_real/Dockerfile init: true tty: true network_mode: host privileged: true volumes: - $PWD:/app - ../../data:/data aloha_ros_nodes: image: aloha_real depends_on: - ros_master build: context: ../.. dockerfile: examples/aloha_real/Dockerfile init: true tty: true network_mode: host privileged: true volumes: - /dev:/dev command: roslaunch --wait aloha ros_nodes.launch ros_master: image: ros:noetic-robot network_mode: host privileged: true command: - roscore openpi_server: image: openpi_server build: context: ../.. dockerfile: scripts/docker/serve_policy.Dockerfile init: true tty: true network_mode: host volumes: - $PWD:/app - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets environment: - SERVER_ARGS - OPENPI_DATA_HOME=/openpi_assets - IS_DOCKER=true # Comment out this block if not running on a machine with GPUs. deploy: resources: reservations: devices: - driver: nvidia count: 1 capabilities: [gpu] ================================================ FILE: examples/aloha_real/constants.py ================================================ # Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act). 
# ruff: noqa

### Task parameters

### ALOHA fixed constants
DT = 0.001
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]

# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
MASTER_GRIPPER_POSITION_OPEN = 0.02417
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844

# Gripper joint limits (qpos[6])
MASTER_GRIPPER_JOINT_OPEN = 0.3083
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213

############################ Helper functions ############################
# Each helper is a plain affine map between a raw range [CLOSE, OPEN] and the
# normalized range [0, 1] (0 = CLOSE, 1 = OPEN). The limits are read at call
# time, exactly as in the lambda-based originals.


def MASTER_GRIPPER_POSITION_NORMALIZE_FN(x):
    """Map a master finger position to [0, 1] (CLOSE -> 0, OPEN -> 1)."""
    return (x - MASTER_GRIPPER_POSITION_CLOSE) / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)


def PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x):
    """Map a puppet finger position to [0, 1] (CLOSE -> 0, OPEN -> 1)."""
    return (x - PUPPET_GRIPPER_POSITION_CLOSE) / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)


def MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(x):
    """Map a normalized value back to a master finger position."""
    return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE


def PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(x):
    """Map a normalized value back to a puppet finger position."""
    return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE


def MASTER2PUPPET_POSITION_FN(x):
    """Convert a master finger position to the equivalent puppet finger position."""
    return PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))


def MASTER_GRIPPER_JOINT_NORMALIZE_FN(x):
    """Map a master gripper joint value to [0, 1] (CLOSE -> 0, OPEN -> 1)."""
    return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)


def PUPPET_GRIPPER_JOINT_NORMALIZE_FN(x):
    """Map a puppet gripper joint value to [0, 1] (CLOSE -> 0, OPEN -> 1)."""
    return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)


def MASTER_GRIPPER_JOINT_UNNORMALIZE_FN(x):
    """Map a normalized value back to a master gripper joint value."""
    return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE


def PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(x):
    """Map a normalized value back to a puppet gripper joint value."""
    return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE


def MASTER2PUPPET_JOINT_FN(x):
    """Convert a master gripper joint value to the equivalent puppet joint value."""
    return PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))


def MASTER_GRIPPER_VELOCITY_NORMALIZE_FN(x):
    """Scale a master finger velocity by the master finger position range."""
    return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)


def PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(x):
    """Scale a puppet finger velocity by the puppet finger position range."""
    return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)


def MASTER_POS2JOINT(x):
    """Convert a master finger position to the matching master gripper joint value."""
    return (
        MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
        + MASTER_GRIPPER_JOINT_CLOSE
    )


def MASTER_JOINT2POS(x):
    """Convert a master gripper joint value to the matching master finger position."""
    return MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
        (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
    )


def PUPPET_POS2JOINT(x):
    """Convert a puppet finger position to the matching puppet gripper joint value."""
    return (
        PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
        + PUPPET_GRIPPER_JOINT_CLOSE
    )


def PUPPET_JOINT2POS(x):
    """Convert a puppet gripper joint value to the matching puppet finger position."""
    return PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
        (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
    )


# Midpoint of the master gripper joint range.
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id / """ import dataclasses from pathlib import Path import shutil from typing import Literal import h5py from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw import numpy as np import torch import tqdm import tyro @dataclasses.dataclass(frozen=True) class DatasetConfig: use_videos: bool = True tolerance_s: float = 0.0001 image_writer_processes: int = 10 image_writer_threads: int = 5 video_backend: str | None = None DEFAULT_DATASET_CONFIG = DatasetConfig() def create_empty_dataset( repo_id: str, robot_type: str, mode: Literal["video", "image"] = "video", *, has_velocity: bool = False, has_effort: bool = False, dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG, ) -> LeRobotDataset: motors = [ "right_waist", "right_shoulder", "right_elbow", "right_forearm_roll", "right_wrist_angle", "right_wrist_rotate", "right_gripper", "left_waist", "left_shoulder", "left_elbow", "left_forearm_roll", "left_wrist_angle", "left_wrist_rotate", "left_gripper", ] cameras = [ "cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist", ] features = { "observation.state": { "dtype": "float32", "shape": (len(motors),), "names": [ motors, ], }, "action": { "dtype": "float32", "shape": (len(motors),), "names": [ motors, ], }, } if has_velocity: features["observation.velocity"] = { "dtype": "float32", "shape": (len(motors),), "names": [ motors, ], } if has_effort: features["observation.effort"] = { "dtype": "float32", "shape": (len(motors),), "names": [ motors, ], } for cam in cameras: features[f"observation.images.{cam}"] = { "dtype": mode, "shape": (3, 480, 640), "names": [ "channels", "height", "width", ], } if Path(LEROBOT_HOME / repo_id).exists(): shutil.rmtree(LEROBOT_HOME / repo_id) return LeRobotDataset.create( 
def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
    """Load every frame of each requested camera from an open episode file.

    Args:
        ep: Open episode file; image data is read from
            ``/observations/images/<camera>``.
        cameras: Camera names to load, in the order they should appear in the result.

    Returns:
        A dict mapping each camera name to an array holding all of its frames.
    """
    frames_by_camera = {}
    for camera in cameras:
        dataset = ep[f"/observations/images/{camera}"]

        if dataset.ndim == 4:
            # A 4-D dataset is treated as raw (uncompressed) frames: read it all into RAM.
            frames_by_camera[camera] = dataset[:]
            continue

        # Otherwise every entry is an encoded image buffer. cv2 is imported lazily
        # so it is only required when compressed data is actually encountered.
        import cv2

        # Decode one image after the other and convert OpenCV's BGR output to RGB.
        frames_by_camera[camera] = np.array(
            [cv2.cvtColor(cv2.imdecode(encoded, 1), cv2.COLOR_BGR2RGB) for encoded in dataset]
        )
    return frames_by_camera
def port_aloha(
    raw_dir: Path,
    repo_id: str,
    raw_repo_id: str | None = None,
    task: str = "DEBUG",
    *,
    episodes: list[int] | None = None,
    push_to_hub: bool = True,
    is_mobile: bool = False,
    mode: Literal["video", "image"] = "image",
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
    """Convert a directory of Aloha ``episode_*.hdf5`` files into a LeRobot dataset.

    Args:
        raw_dir: Directory holding the raw ``episode_*.hdf5`` files. If it does not
            exist, the raw data is downloaded into it from ``raw_repo_id``.
        repo_id: Id of the dataset to create under ``LEROBOT_HOME``. Any existing
            dataset directory with this id is deleted first.
        raw_repo_id: Repo to download the raw data from when ``raw_dir`` is missing.
        task: Task string stored with every saved episode.
        episodes: Indices (into the sorted list of episode files) to convert;
            ``None`` converts all of them.
        push_to_hub: Whether to call ``push_to_hub()`` on the finished dataset.
        is_mobile: Selects the ``"mobile_aloha"`` robot type instead of ``"aloha"``.
        mode: Storage dtype used for the camera features.
        dataset_config: Extra options forwarded to the dataset constructor.

    Raises:
        ValueError: If ``raw_dir`` does not exist and ``raw_repo_id`` is not given.
        FileNotFoundError: If ``raw_dir`` contains no ``episode_*.hdf5`` files.
    """
    if (LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    if not raw_dir.exists():
        if raw_repo_id is None:
            raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
        download_raw(raw_dir, repo_id=raw_repo_id)

    hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
    if not hdf5_files:
        # has_effort / has_velocity probe hdf5_files[0]; fail here with a clear
        # message instead of an opaque IndexError further down.
        raise FileNotFoundError(f"No episode_*.hdf5 files found in {raw_dir}")

    dataset = create_empty_dataset(
        repo_id,
        robot_type="mobile_aloha" if is_mobile else "aloha",
        mode=mode,
        has_effort=has_effort(hdf5_files),
        has_velocity=has_velocity(hdf5_files),
        dataset_config=dataset_config,
    )
    dataset = populate_dataset(
        dataset,
        hdf5_files,
        task=task,
        episodes=episodes,
    )
    dataset.consolidate()

    if push_to_hub:
        dataset.push_to_hub()
class AlohaRealEnvironment(_environment.Environment):
    """An environment for an Aloha robot on real hardware."""

    def __init__(
        self,
        reset_position: Optional[List[float]] = None,  # noqa: UP006,UP007
        render_height: int = 224,
        render_width: int = 224,
    ) -> None:
        """Create the underlying real-robot environment.

        Args:
            reset_position: Optional reset position forwarded to the real environment.
            render_height: Height the camera images are resized (with padding) to.
            render_width: Width the camera images are resized (with padding) to.
        """
        self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
        self._render_height = render_height
        self._render_width = render_width

        # Most recent timestep returned by the robot; None until reset() is called.
        self._ts = None

    @override
    def reset(self) -> None:
        """Reset the robot and store the resulting first timestep."""
        self._ts = self._env.reset()

    @override
    def is_episode_complete(self) -> bool:
        """Always False: this environment never signals episode completion itself."""
        return False

    @override
    def get_observation(self) -> dict:
        """Return the current state and the processed camera images.

        Depth entries (keys containing ``"_depth"``) are dropped. Every remaining
        image is resized with padding to the render size, converted to uint8 and
        rearranged from ``h w c`` to ``c h w``.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self._ts is None:
            raise RuntimeError("Timestep is not set. Call reset() first.")

        obs = self._ts.observation

        # Build a fresh dict instead of editing the stored timestep in place, so a
        # second call on the same timestep does not re-process already converted images.
        images = {}
        for cam_name, raw_image in obs["images"].items():
            if "_depth" in cam_name:
                continue
            img = image_tools.convert_to_uint8(
                image_tools.resize_with_pad(raw_image, self._render_height, self._render_width)
            )
            images[cam_name] = einops.rearrange(img, "h w c -> c h w")

        return {
            "state": obs["qpos"],
            "images": images,
        }

    @override
    def apply_action(self, action: dict) -> None:
        """Step the robot with ``action["actions"]`` and store the new timestep."""
        self._ts = self._env.step(action["actions"])
main(args: Args) -> None: ws_client_policy = _websocket_client_policy.WebsocketClientPolicy( host=args.host, port=args.port, ) logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}") metadata = ws_client_policy.get_server_metadata() runtime = _runtime.Runtime( environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")), agent=_policy_agent.PolicyAgent( policy=action_chunk_broker.ActionChunkBroker( policy=ws_client_policy, action_horizon=args.action_horizon, ) ), subscribers=[], max_hz=50, num_episodes=args.num_episodes, max_episode_steps=args.max_episode_steps, ) runtime.run() if __name__ == "__main__": logging.basicConfig(level=logging.INFO, force=True) tyro.cli(main) ================================================ FILE: examples/aloha_real/real_env.py ================================================ # Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act). # ruff: noqa import collections import time from typing import Optional, List import dm_env from interbotix_xs_modules.arm import InterbotixManipulatorXS from interbotix_xs_msgs.msg import JointSingleCommand import numpy as np from examples.aloha_real import constants from examples.aloha_real import robot_utils # This is the reset position that is used by the standard Aloha runtime. 
class RealEnv:
    """
    Environment for real robot bi-manual manipulation

    Action space:      [left_arm_qpos (6),             # absolute joint position
                        left_gripper_positions (1),    # normalized gripper position (0: close, 1: open)
                        right_arm_qpos (6),            # absolute joint position
                        right_gripper_positions (1),]  # normalized gripper position (0: close, 1: open)

    Observation space: {"qpos": Concat[ left_arm_qpos (6),          # absolute joint position
                                        left_gripper_position (1),  # normalized gripper position (0: close, 1: open)
                                        right_arm_qpos (6),         # absolute joint position
                                        right_gripper_qpos (1)]     # normalized gripper position (0: close, 1: open)
                        "qvel": Concat[ left_arm_qvel (6),          # absolute joint velocity (rad)
                                        left_gripper_velocity (1),  # normalized gripper velocity (pos: opening, neg: closing)
                                        right_arm_qvel (6),         # absolute joint velocity (rad)
                                        right_gripper_qvel (1)]     # normalized gripper velocity (pos: opening, neg: closing)
                        "effort": Concat[ left_effort (7),          # first 7 entries of the left recorder's effort
                                          right_effort (7)]         # first 7 entries of the right recorder's effort
                        "images": {"cam_high": (480x640x3),         # h, w, c, dtype='uint8'
                                   "cam_low": (480x640x3),          # h, w, c, dtype='uint8'
                                   "cam_left_wrist": (480x640x3),   # h, w, c, dtype='uint8'
                                   "cam_right_wrist": (480x640x3)}  # h, w, c, dtype='uint8'
                                   # The image recorder also returns a "<cam>_depth" entry per camera.
    """

    def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
        """Connect to both puppet arms and create the state and image recorders.

        Args:
            init_node: Passed as ``init_node`` to the left puppet arm interface; the
                right arm and all recorders are always created with ``init_node=False``.
            reset_position: Optional arm reset position; only its first 6 entries are
                used. A missing or empty value falls back to ``DEFAULT_RESET_POSITION``.
            setup_robots: Whether to call ``setup_robots()`` during construction.
        """
        # reset_position = START_ARM_POSE[:6]
        self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION

        self.puppet_bot_left = InterbotixManipulatorXS(
            robot_model="vx300s",
            group_name="arm",
            gripper_name="gripper",
            robot_name="puppet_left",
            init_node=init_node,
        )
        self.puppet_bot_right = InterbotixManipulatorXS(
            robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
        )
        if setup_robots:
            self.setup_robots()

        self.recorder_left = robot_utils.Recorder("left", init_node=False)
        self.recorder_right = robot_utils.Recorder("right", init_node=False)
        self.image_recorder = robot_utils.ImageRecorder(init_node=False)
        # A single command message that set_gripper_pose reuses for both grippers.
        self.gripper_command = JointSingleCommand(name="gripper")

    def setup_robots(self):
        """Run ``robot_utils.setup_puppet_bot`` on the left and then the right puppet arm."""
        robot_utils.setup_puppet_bot(self.puppet_bot_left)
        robot_utils.setup_puppet_bot(self.puppet_bot_right)

    def get_qpos(self):
        """Return the 14-dim position vector: (6 arm joints + 1 normalized gripper) per arm, left first."""
        left_qpos_raw = self.recorder_left.qpos
        right_qpos_raw = self.recorder_right.qpos
        left_arm_qpos = left_qpos_raw[:6]
        right_arm_qpos = right_qpos_raw[:6]
        # Index 7 of the raw qpos is normalized with the puppet finger *position* limits.
        left_gripper_qpos = [
            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
        ]  # this is position not joint
        right_gripper_qpos = [
            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
        ]  # this is position not joint
        return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])

    def get_qvel(self):
        """Return the 14-dim velocity vector: (6 arm joints + 1 normalized gripper) per arm, left first."""
        left_qvel_raw = self.recorder_left.qvel
        right_qvel_raw = self.recorder_right.qvel
        left_arm_qvel = left_qvel_raw[:6]
        right_arm_qvel = right_qvel_raw[:6]
        left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
        right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
        return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])

    def get_effort(self):
        """Return the first 7 effort entries of the left recorder followed by those of the right."""
        left_effort_raw = self.recorder_left.effort
        right_effort_raw = self.recorder_right.effort
        left_robot_effort = left_effort_raw[:7]
        right_robot_effort = right_effort_raw[:7]
        return np.concatenate([left_robot_effort, right_robot_effort])

    def get_images(self):
        """Return the image dict produced by the image recorder."""
        return self.image_recorder.get_images()

    def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
        """Publish gripper commands for both arms from normalized targets.

        Each normalized value is converted to a puppet gripper joint value and
        published on the corresponding arm, left first.
        """
        left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
        self.gripper_command.cmd = left_gripper_desired_joint
        self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)

        right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
            right_gripper_desired_pos_normalized
        )
        # The same message object is reused: overwrite cmd, then publish to the right arm.
        self.gripper_command.cmd = right_gripper_desired_joint
        self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)

    def _reset_joints(self):
        """Move both arms to the configured reset position."""
        robot_utils.move_arms(
            [self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
        )

    def _reset_gripper(self):
        """Set to position mode and do position resets: first close then open. Then change back to PWM mode

        NOTE: This diverges from the original Aloha code which first opens then closes the gripper. Pi internal aloha data
        was collected with the gripper starting in the open position. Leaving the grippers fully closed was also found to
        increase the frequency of motor faults.
        """
        robot_utils.move_grippers(
            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
        )
        robot_utils.move_grippers(
            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
        )

    def get_observation(self):
        """Collect qpos, qvel, effort and images into an ordered dict."""
        obs = collections.OrderedDict()
        obs["qpos"] = self.get_qpos()
        obs["qvel"] = self.get_qvel()
        obs["effort"] = self.get_effort()
        obs["images"] = self.get_images()
        return obs

    def get_reward(self):
        """Return the reward, which is always 0 for this environment."""
        return 0

    def reset(self, *, fake=False):
        """Reset the robot and return the first timestep.

        Args:
            fake: When True, skip the motor reboot and the joint/gripper resets and
                only return a fresh observation.
        """
        if not fake:
            # Reboot puppet robot gripper motors
            self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
            self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
            self._reset_joints()
            self._reset_gripper()
        return dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
        )

    def step(self, action):
        """Apply one action to both arms and return the resulting timestep.

        The action is split in half (left first). The first 6 entries of each half
        are sent as non-blocking arm joint targets and the last entry of each half
        is the normalized gripper target.
        """
        state_len = int(len(action) / 2)
        left_action = action[:state_len]
        right_action = action[state_len:]
        self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
        self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
        self.set_gripper_pose(left_action[-1], right_action[-1])
        # Pause for constants.DT before reading back the observation.
        time.sleep(constants.DT)
        return dm_env.TimeStep(
            step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
        )
h5py==3.11.0 # via -r examples/aloha_real/requirements.in idna==3.10 # via requests importlib-resources==6.4.5 # via etils kiwisolver==1.4.7 # via matplotlib labmaze==1.0.6 # via dm-control lxml==5.3.0 # via dm-control markdown-it-py==3.0.0 # via rich matplotlib==3.7.5 # via -r examples/aloha_real/requirements.in mdurl==0.1.2 # via markdown-it-py modern-robotics==1.1.1 # via -r examples/aloha_real/requirements.in msgpack==1.1.0 # via -r examples/aloha_real/requirements.in mujoco==3.2.3 # via dm-control numpy==1.24.4 # via # -r examples/aloha_real/requirements.in # contourpy # dm-control # dm-env # h5py # labmaze # matplotlib # modern-robotics # mujoco # opencv-python # pyquaternion # scipy opencv-python==4.10.0.84 # via -r examples/aloha_real/requirements.in packaging==24.2 # via # -r examples/aloha_real/requirements.in # matplotlib pexpect==4.9.0 # via -r examples/aloha_real/requirements.in pillow==10.4.0 # via # -r examples/aloha_real/requirements.in # matplotlib protobuf==5.29.1 # via dm-control ptyprocess==0.7.0 # via pexpect pygments==2.18.0 # via rich pyopengl==3.1.7 # via # dm-control # mujoco pyparsing==3.1.4 # via # catkin-pkg # dm-control # matplotlib pyquaternion==0.9.9 # via -r examples/aloha_real/requirements.in pyrealsense2==2.55.1.6486 # via -r examples/aloha_real/requirements.in python-dateutil==2.9.0.post0 # via # catkin-pkg # matplotlib pyyaml==6.0.2 # via # -r examples/aloha_real/requirements.in # rospkg requests==2.32.3 # via # -r examples/aloha_real/requirements.in # dm-control rich==13.9.4 # via tyro rospkg==1.5.1 # via -r examples/aloha_real/requirements.in scipy==1.10.1 # via dm-control setuptools==75.3.0 # via # catkin-pkg # dm-control # labmaze shtab==1.7.1 # via tyro six==1.17.0 # via python-dateutil tqdm==4.67.1 # via dm-control typeguard==4.4.0 # via tyro typing-extensions==4.12.2 # via # etils # rich # typeguard # tyro tyro==0.9.2 # via -r examples/aloha_real/requirements.in urllib3==2.2.3 # via requests websockets==14.1 # via -r 
examples/aloha_real/requirements.in zipp==3.20.2 # via etils ================================================ FILE: examples/aloha_real/robot_utils.py ================================================ # Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act). # ruff: noqa from collections import deque import datetime import json import time from aloha.msg import RGBGrayscaleImage from cv_bridge import CvBridge from interbotix_xs_msgs.msg import JointGroupCommand from interbotix_xs_msgs.msg import JointSingleCommand import numpy as np import rospy from sensor_msgs.msg import JointState from examples.aloha_real import constants class ImageRecorder: def __init__(self, init_node=True, is_debug=False): self.is_debug = is_debug self.bridge = CvBridge() self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"] if init_node: rospy.init_node("image_recorder", anonymous=True) for cam_name in self.camera_names: setattr(self, f"{cam_name}_rgb_image", None) setattr(self, f"{cam_name}_depth_image", None) setattr(self, f"{cam_name}_timestamp", 0.0) if cam_name == "cam_high": callback_func = self.image_cb_cam_high elif cam_name == "cam_low": callback_func = self.image_cb_cam_low elif cam_name == "cam_left_wrist": callback_func = self.image_cb_cam_left_wrist elif cam_name == "cam_right_wrist": callback_func = self.image_cb_cam_right_wrist else: raise NotImplementedError rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func) if self.is_debug: setattr(self, f"{cam_name}_timestamps", deque(maxlen=50)) self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names} time.sleep(0.5) def image_cb(self, cam_name, data): setattr( self, f"{cam_name}_rgb_image", self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"), ) # setattr( # self, # f"{cam_name}_depth_image", # self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"), # ) setattr( self, f"{cam_name}_timestamp", 
class Recorder:
    """Caches the most recent joint state and commands seen for one puppet arm.

    On construction the recorder subscribes to the ``/puppet_{side}`` joint-state,
    joint-group-command and joint-single-command topics. Each callback stores the
    latest message contents on the instance. When ``is_debug`` is set, the
    wall-clock arrival time of every message is also kept (last 50 per topic) so
    that ``print_diagnostics`` can report message rates.
    """

    def __init__(self, side, init_node=True, is_debug=False):
        """Subscribe to the puppet topics for ``side``.

        Args:
            side: Inserted into the topic names as ``/puppet_{side}/...``.
            init_node: If true, call ``rospy.init_node("recorder", anonymous=True)``
                before subscribing.
            is_debug: If true, record message arrival times for diagnostics.
        """
        self.secs = None
        self.nsecs = None
        self.qpos = None
        # qvel and data are written by puppet_state_cb alongside qpos/effort.
        # Define them here too so that reading them before the first message
        # yields None like the other fields instead of raising AttributeError.
        self.qvel = None
        self.effort = None
        self.data = None
        self.arm_command = None
        self.gripper_command = None
        self.is_debug = is_debug

        # The callbacks append to these deques whenever is_debug is set, so they
        # must exist before any subscriber is registered (rospy may deliver a
        # message from another thread as soon as the subscription is created).
        if self.is_debug:
            self.joint_timestamps = deque(maxlen=50)
            self.arm_command_timestamps = deque(maxlen=50)
            self.gripper_command_timestamps = deque(maxlen=50)

        if init_node:
            rospy.init_node("recorder", anonymous=True)
        rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
        rospy.Subscriber(
            f"/puppet_{side}/commands/joint_group",
            JointGroupCommand,
            self.puppet_arm_commands_cb,
        )
        rospy.Subscriber(
            f"/puppet_{side}/commands/joint_single",
            JointSingleCommand,
            self.puppet_gripper_commands_cb,
        )
        # Presumably gives the subscribers a moment to receive their first messages.
        time.sleep(0.1)

    def puppet_state_cb(self, data):
        """Store position, velocity, effort and the raw message from a joint-state update."""
        self.qpos = data.position
        self.qvel = data.velocity
        self.effort = data.effort
        self.data = data
        if self.is_debug:
            self.joint_timestamps.append(time.time())

    def puppet_arm_commands_cb(self, data):
        """Store the ``cmd`` field of the latest arm (joint group) command."""
        self.arm_command = data.cmd
        if self.is_debug:
            self.arm_command_timestamps.append(time.time())

    def puppet_gripper_commands_cb(self, data):
        """Store the ``cmd`` field of the latest gripper (single joint) command."""
        self.gripper_command = data.cmd
        if self.is_debug:
            self.gripper_command_timestamps.append(time.time())

    def print_diagnostics(self):
        """Print the mean message rate of each topic.

        Rates are computed from the recorded arrival times, which only exist when
        the recorder was created with ``is_debug=True``.
        """

        def dt_helper(timestamps):
            # Mean gap between consecutive arrival times.
            timestamps = np.array(timestamps)
            gaps = timestamps[1:] - timestamps[:-1]
            return np.mean(gaps)

        joint_freq = 1 / dt_helper(self.joint_timestamps)
        arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
        gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)

        print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
def setup_puppet_bot(bot):
    """Reboot the gripper motor, set arm/gripper operating modes, then enable torque."""
    driver = bot.dxl
    driver.robot_reboot_motors("single", "gripper", True)
    mode_settings = (
        ("group", "arm", "position"),
        ("single", "gripper", "current_based_position"),
    )
    for cmd_type, target, mode in mode_settings:
        driver.robot_set_operating_modes(cmd_type, target, mode)
    torque_on(bot)


def setup_master_bot(bot):
    """Set arm/gripper operating modes for a master arm, then disable torque."""
    driver = bot.dxl
    mode_settings = (
        ("group", "arm", "pwm"),
        ("single", "gripper", "current_based_position"),
    )
    for cmd_type, target, mode in mode_settings:
        driver.robot_set_operating_modes(cmd_type, target, mode)
    torque_off(bot)


def set_standard_pid_gains(bot):
    """Write the standard arm gains: Position_P_Gain=800, Position_I_Gain=0."""
    for register, value in (("Position_P_Gain", 800), ("Position_I_Gain", 0)):
        bot.dxl.robot_set_motor_registers("group", "arm", register, value)


def set_low_pid_gains(bot):
    """Write the low arm gains: Position_P_Gain=100, Position_I_Gain=0."""
    for register, value in (("Position_P_Gain", 100), ("Position_I_Gain", 0)):
        bot.dxl.robot_set_motor_registers("group", "arm", register, value)


def torque_off(bot):
    """Disable torque on the arm group and then on the gripper."""
    for cmd_type, target in (("group", "arm"), ("single", "gripper")):
        bot.dxl.robot_torque_enable(cmd_type, target, False)


def torque_on(bot):
    """Enable torque on the arm group and then on the gripper."""
    for cmd_type, target in (("group", "arm"), ("single", "gripper")):
        bot.dxl.robot_torque_enable(cmd_type, target, True)
class VideoDisplay(_subscriber.Subscriber):
    """Subscriber that shows the observation image of each step in a matplotlib window."""

    def __init__(self) -> None:
        # Both handles start empty: the axes are created when an episode starts,
        # the image artist on the first step of that episode.
        self._plt_img: plt.Image | None = None
        self._ax: plt.Axes | None = None

    @override
    def on_episode_start(self) -> None:
        """Switch matplotlib to interactive mode and start with fresh axes."""
        plt.ion()
        self._ax = plt.subplot()
        self._plt_img = None

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        """Draw ``observation["image"][0]``, reusing the image artist after the first frame."""
        assert self._ax is not None

        frame_chw = observation["image"][0]  # [C, H, W]
        frame_hwc = np.transpose(frame_chw, (1, 2, 0))  # [H, W, C]

        if self._plt_img is not None:
            # Later frames only swap the pixel data of the existing artist.
            self._plt_img.set_data(frame_hwc)
        else:
            self._plt_img = self._ax.imshow(frame_hwc)
        plt.pause(0.001)

    @override
    def on_episode_end(self) -> None:
        """Leave interactive mode and close the figure."""
        plt.ioff()
        plt.close()
-t aloha_sim -f examples/aloha_sim/Dockerfile # Run the container: # docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78 COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ RUN apt-get update && \ apt-get install -y \ libosmesa6-dev \ libgl1-mesa-glx \ libglew-dev \ libglfw3-dev \ libgles2-mesa-dev ENV MUJOCO_GL=egl WORKDIR /app # Copy from the cache instead of linking since it's a mounted volume ENV UV_LINK_MODE=copy # Write the virtual environment outside of the project directory so it doesn't # leak out of the container when we mount the application code. ENV UV_PROJECT_ENVIRONMENT=/.venv # Copy the requirements files so we can install dependencies. # The rest of the project is mounted as a volume, so we don't need to rebuild on changes. # This strategy is best for development-style usage. COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml # Install python dependencies. 
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"] ================================================ FILE: examples/aloha_sim/README.md ================================================ # Run Aloha Sim ## With Docker ```bash export SERVER_ARGS="--env ALOHA_SIM" docker compose -f examples/aloha_sim/compose.yml up --build ``` ## Without Docker Terminal window 1: ```bash # Create virtual environment uv venv --python 3.10 examples/aloha_sim/.venv source examples/aloha_sim/.venv/bin/activate uv pip sync examples/aloha_sim/requirements.txt uv pip install -e packages/openpi-client # Run the simulation MUJOCO_GL=egl python examples/aloha_sim/main.py ``` Note: If you are seeing EGL errors, you may need to install the following dependencies: ```bash sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev ``` Terminal window 2: ```bash # Run the server uv run scripts/serve_policy.py --env ALOHA_SIM ``` ================================================ FILE: examples/aloha_sim/compose.yml ================================================ # Run with: # docker compose -f examples/aloha_sim/compose.yml up --build services: runtime: image: aloha_sim depends_on: - openpi_server build: context: ../.. dockerfile: examples/aloha_sim/Dockerfile init: true tty: true network_mode: host privileged: true volumes: - $PWD:/app - ../../data:/data openpi_server: image: openpi_server build: context: ../.. dockerfile: scripts/docker/serve_policy.Dockerfile init: true tty: true network_mode: host volumes: - $PWD:/app - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets environment: - SERVER_ARGS - OPENPI_DATA_HOME=/openpi_assets - IS_DOCKER=true # Comment out this block if not running on a machine with GPUs. 
class AlohaSimEnvironment(_environment.Environment):
    """Runtime environment backed by a gymnasium Aloha simulation task."""

    def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
        # Seed both the legacy global NumPy RNG and the private generator that
        # supplies per-episode reset seeds.
        np.random.seed(seed)
        self._rng = np.random.default_rng(seed)

        self._gym = gymnasium.make(task, obs_type=obs_type)

        # Until reset() is called there is no observation and the episode counts as finished.
        self._last_obs = None
        self._done = True
        self._episode_reward = 0.0

    @override
    def reset(self) -> None:
        """Start a new episode with a seed drawn from the environment's own generator."""
        episode_seed = int(self._rng.integers(2**32 - 1))
        gym_obs, _ = self._gym.reset(seed=episode_seed)
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = False
        self._episode_reward = 0.0

    @override
    def is_episode_complete(self) -> bool:
        """Return whether the last step terminated or truncated the episode."""
        return self._done

    @override
    def get_observation(self) -> dict:
        """Return the most recent converted observation.

        Raises:
            RuntimeError: If no observation has been produced yet.
        """
        if self._last_obs is not None:
            return self._last_obs  # type: ignore

        raise RuntimeError("Observation is not set. Call reset() first.")

    @override
    def apply_action(self, action: dict) -> None:
        """Step the simulation with ``action["actions"]`` and update the cached state."""
        step_result = self._gym.step(action["actions"])
        gym_obs, reward, terminated, truncated, info = step_result
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = terminated or truncated
        # Keep the best single-step reward seen during the episode.
        self._episode_reward = max(self._episode_reward, reward)

    def _convert_observation(self, gym_obs: dict) -> dict:
        """Turn a gym observation into the ``state`` / ``images`` dict used by the runtime."""
        top_view = gym_obs["pixels"]["top"]
        padded = image_tools.resize_with_pad(top_view, 224, 224)
        as_uint8 = image_tools.convert_to_uint8(padded)
        # Convert axis order from [H, W, C] --> [C, H, W]
        channels_first = np.transpose(as_uint8, (2, 0, 1))
        images = {"cam_high": channels_first}
        return {"state": gym_obs["agent_pos"], "images": images}
examples/aloha_sim/requirements.in ================================================ gym-aloha imageio matplotlib msgpack numpy>=1.22.4,<2.0.0 typing-extensions tyro websockets ================================================ FILE: examples/aloha_sim/requirements.txt ================================================ # This file was autogenerated by uv via the following command: # uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10 absl-py==2.1.0 # via # dm-control # dm-env # labmaze # mujoco certifi==2024.8.30 # via requests charset-normalizer==3.4.0 # via requests cloudpickle==3.1.0 # via gymnasium contourpy==1.3.1 # via matplotlib cycler==0.12.1 # via matplotlib dm-control==1.0.14 # via gym-aloha dm-env==1.6 # via dm-control dm-tree==0.1.8 # via # dm-control # dm-env docstring-parser==0.16 # via tyro farama-notifications==0.0.4 # via gymnasium fonttools==4.55.2 # via matplotlib glfw==2.8.0 # via # dm-control # mujoco gym-aloha==0.1.1 # via -r examples/aloha_sim/requirements.in gymnasium==1.0.0 # via gym-aloha idna==3.10 # via requests imageio==2.36.1 # via # -r examples/aloha_sim/requirements.in # gym-aloha imageio-ffmpeg==0.5.1 # via imageio kiwisolver==1.4.7 # via matplotlib labmaze==1.0.6 # via dm-control lxml==5.3.0 # via dm-control markdown-it-py==3.0.0 # via rich matplotlib==3.9.3 # via -r examples/aloha_sim/requirements.in mdurl==0.1.2 # via markdown-it-py msgpack==1.1.0 # via -r examples/aloha_sim/requirements.in mujoco==2.3.7 # via # dm-control # gym-aloha numpy==1.26.4 # via # -r examples/aloha_sim/requirements.in # contourpy # dm-control # dm-env # gymnasium # imageio # labmaze # matplotlib # mujoco # scipy packaging==24.2 # via matplotlib pillow==11.0.0 # via # imageio # matplotlib protobuf==5.29.1 # via dm-control psutil==6.1.0 # via imageio pygments==2.18.0 # via rich pyopengl==3.1.7 # via # dm-control # mujoco pyparsing==3.2.0 # via # dm-control # matplotlib python-dateutil==2.9.0.post0 # via 
class VideoSaver(_subscriber.Subscriber):
    """Collects the ``cam_high`` frame of every step and writes one mp4 per episode."""

    def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
        # Create the output directory up front so episode end only has to write the file.
        out_dir.mkdir(parents=True, exist_ok=True)
        self._subsample = subsample
        self._images: list[np.ndarray] = []
        self._out_dir = out_dir

    @override
    def on_episode_start(self) -> None:
        """Drop frames left over from a previous episode."""
        self._images = []

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        """Store the step's ``cam_high`` image, converted to channels-last."""
        frame_chw = observation["images"]["cam_high"]  # [C, H, W]
        frame_hwc = np.transpose(frame_chw, (1, 2, 0))  # [H, W, C]
        self._images.append(frame_hwc)

    @override
    def on_episode_end(self) -> None:
        """Write the collected frames to the next free ``out_<idx>.mp4`` in the output directory."""
        # Next index is one past the highest index already present (0 when there are no videos).
        used_indices = [int(path.stem.split("_")[1]) for path in self._out_dir.glob("out_[0-9]*.mp4")]
        next_idx = max(used_indices, default=-1) + 1
        out_path = self._out_dir / f"out_{next_idx}.mp4"

        logging.info(f"Saving video to {out_path}")

        # Keep every `subsample`-th frame and lower the frame rate by the same factor.
        kept_frames = self._images[:: self._subsample]
        frame_rate = 50 // max(1, self._subsample)
        imageio.mimwrite(
            out_path,
            [np.asarray(frame) for frame in kept_frames],
            fps=frame_rate,
        )
def slice_paligemma_state_dict(state_dict, config):
    """Convert PaliGemma JAX parameters to PyTorch format.

    Vision-tower (``img/...``) and language-model (``llm/...``) entries are popped
    from ``state_dict`` and re-inserted under
    ``paligemma_with_expert.paligemma.model...`` names, transposed/reshaped as
    needed. Per-layer parameters are stored stacked in the JAX checkpoint and are
    split into one entry per layer by indexing the leading axis.

    Note that ``state_dict`` is modified in place.

    Args:
        state_dict: Flat mapping from ``/``-joined JAX parameter paths to arrays.
            Keys may or may not carry a trailing ``/value``; this is detected
            automatically. Values must be accepted by ``torch.from_numpy``.
        config: Object exposing ``vision_config.hidden_size``,
            ``vision_config.num_hidden_layers`` and
            ``text_config.{num_hidden_layers, num_attention_heads, head_dim, hidden_size}``.

    Returns:
        A tuple ``(final_state_dict, expert_dict)``. ``final_state_dict`` holds every
        remaining entry as a torch tensor. ``expert_dict`` holds the entries whose
        key is one of the action-expert keys listed below, left unconverted.
    """
    # Some checkpoints store each leaf under ".../value"; probe one known key to find out.
    suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""

    # patch embeddings
    jax_key = f"img/embedding/kernel{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
    # NOTE(review): presumably reorders a conv kernel from (H, W, in, out) to (out, in, H, W) - confirm.
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)

    jax_key = f"img/embedding/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # positional embeddings
    jax_key = f"img/pos_embedding{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
    # Flatten to 2-D with hidden_size as the last dimension.
    state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)

    # extract vision layers to be sliced at index 0. There are 27 layers in the base model.
    encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
    encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
    encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
    encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")

    encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
    encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
    encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
    encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")

    encoderblock_attention_0_key_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
    )
    encoderblock_attention_0_key_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
    )
    encoderblock_attention_0_value_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
    )
    encoderblock_attention_0_value_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
    )
    encoderblock_attention_0_query_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
    )
    encoderblock_attention_0_query_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
    )
    encoderblock_attention_0_out_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
    )
    encoderblock_attention_0_out_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
    )

    # Emit one set of PyTorch entries per vision layer; index i selects that layer's
    # slice from each stacked array. Kernels are transposed, biases are copied.
    for i in range(config.vision_config.num_hidden_layers):
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
        ] = encoderblock_layernorm0_scale[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
        ] = encoderblock_layernorm0_bias[i]

        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
        ] = encoderblock_layernorm1_scale[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
        ] = encoderblock_layernorm1_bias[i]

        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
        ] = encoderblock_mlp_dense0_kernel[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
        ] = encoderblock_mlp_dense0_bias[i]

        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
        ] = encoderblock_mlp_dense1_kernel[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
        ] = encoderblock_mlp_dense1_bias[i]

        # Attention projections: kernels are flattened to 2-D with hidden_size as the
        # last dimension and then transposed; biases are flattened to 1-D.
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
        ] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
        ] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)

        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
        ] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
        ] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)

        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
        ] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
        ] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)

        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
        ] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
        ] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)

    # Final layer norm of the vision encoder.
    jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()

    jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # multimodal projector
    jax_key = f"img/head/kernel{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()

    jax_key = f"img/head/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # text decoder (gemma)
    jax_key = f"llm/embedder/input_embedding{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # pop the einsum attention + mlp representations
    llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
    llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
    llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")

    llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
    llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")

    llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
    llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")

    # Emit one set of PyTorch entries per language-model layer; index i selects that
    # layer's slice from each stacked array.
    for i in range(config.text_config.num_hidden_layers):
        q_proj_weight_reshaped = (
            llm_attention_q_einsum[i]
            .transpose(0, 2, 1)
            .reshape(
                config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
            )
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
            q_proj_weight_reshaped
        )

        # The second index of kv_einsum selects key (0) or value (1).
        # NOTE(review): the third index 0 presumably picks the single shared KV head - confirm.
        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
            k_proj_weight_reshaped
        )
        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
            v_proj_weight_reshaped
        )

        o_proj_weight_reshaped = (
            llm_attention_attn_vec_einsum[i]
            .transpose(2, 0, 1)
            .reshape(
                config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
            )
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
            o_proj_weight_reshaped
        )

        # The second index of gating_einsum selects the gate (0) or up (1) projection.
        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
            gate_proj_weight.transpose()
        )
        up_proj_weight = llm_mlp_gating_einsum[i, 1]
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
            up_proj_weight.transpose()
        )

        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
            llm_mlp_linear[i].transpose()
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
            llm_input_layernorm[i]
        )
        state_dict[
            f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
        ] = llm_post_attention_layernorm[i]

    jax_key = f"llm/final_norm/scale{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    expert_dict = {}
    final_state_dict = {}

    # Expert-related keys to extract (including pi05 Dense layer parameters)
    expert_keys = [
        f"llm/final_norm_1/scale{suffix}",
        f"llm/final_norm_1/Dense_0/bias{suffix}",
        f"llm/final_norm_1/Dense_0/kernel{suffix}",
        f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
        f"llm/layers/attn/kv_einsum_1/w{suffix}",
        f"llm/layers/attn/q_einsum_1/w{suffix}",
        f"llm/layers/mlp_1/gating_einsum{suffix}",
        f"llm/layers/mlp_1/linear{suffix}",
        f"llm/layers/pre_attention_norm_1/scale{suffix}",
        f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}",
        f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}",
        f"llm/layers/pre_ffw_norm_1/scale{suffix}",
        f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
        f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
    ]

    # Split what is left: expert entries are passed through untouched, everything
    # else (converted entries and any other remaining keys) becomes a torch tensor.
    for key, value in state_dict.items():
        if key not in expert_keys:
            final_state_dict[key] = torch.from_numpy(value)
        else:
            expert_dict[key] = value

    return final_state_dict, expert_dict
llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}") llm_input_layernorm_kernel = state_dict.pop( f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}" ) llm_post_attention_layernorm_kernel = state_dict.pop( f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}" ) else: # Regular pi0 with standard RMSNorm llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}") llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}") for i in range(config.num_hidden_layers): q_proj_weight_reshaped = ( llm_attention_q_einsum[i] .transpose(0, 2, 1) .reshape(config.num_attention_heads * config.head_dim, config.hidden_size) ) state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = ( q_proj_weight_reshaped ) k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = ( k_proj_weight_reshaped ) v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = ( v_proj_weight_reshaped ) o_proj_weight_reshaped = ( llm_attention_attn_vec_einsum[i] .reshape(config.num_attention_heads * config.head_dim, config.hidden_size) .transpose(1, 0) ) state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = ( o_proj_weight_reshaped ) gate_proj_weight = llm_mlp_gating_einsum[i, 0] state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = ( gate_proj_weight.transpose() ) up_proj_weight = llm_mlp_gating_einsum[i, 1] state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = ( up_proj_weight.transpose() ) state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[ i 
].transpose() if "pi05" in checkpoint_dir: # Pi05 with adaptive normalization - use Dense layer parameters directly state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = ( llm_input_layernorm_bias[i] ) state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = ( llm_post_attention_layernorm_bias[i] ) state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = ( llm_input_layernorm_kernel[i].transpose() ) state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = ( llm_post_attention_layernorm_kernel[i].transpose() ) else: # Regular pi0 with standard RMSNorm state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = ( llm_input_layernorm[i] ) state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = ( llm_post_attention_layernorm[i] ) # Handle final norm layer if "pi05" in checkpoint_dir: # Pi05 with adaptive normalization - use Dense layer parameters directly final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}") final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}") state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose() else: # Regular pi0 with standard RMSNorm state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop( f"llm/final_norm_{num_expert}/scale{suffix}" ) # state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. 
final_state_dict = {} for key, value in state_dict.items(): if not isinstance(value, torch.Tensor): final_state_dict[key] = torch.from_numpy(value) else: final_state_dict[key] = value return final_state_dict def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None): """Load and process params by restoring via JAX model loader first. This respects dtype conversions that occur during model restore. """ # Use repository restore utility to load a pure dict of params (value suffix removed) params = openpi.models.model.restore_params( f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision ) return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params} def load_jax_model_and_print_keys(checkpoint_dir: str): """ Load JAX model from checkpoint and print all parameter keys. Args: checkpoint_dir: Path to the checkpoint directory """ checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir # Initialize checkpointer checkpointer = ocp.PyTreeCheckpointer() metadata = checkpointer.metadata(f"{checkpoint_dir}/params") print(utils.array_tree_to_info(metadata)) def convert_pi0_checkpoint( checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config ): """ Convert PI0 JAX checkpoint to PyTorch format. 
Args: checkpoint_dir: Path to the JAX checkpoint precision: Model precision (float32, bfloat16, float16) output_path: Path to save the converted PyTorch model model_config: Model config """ print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}") print(f"Model config: {model_config}") # Break down orbax ckpts by restoring via JAX to respect dtype initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32") # Process projection params if model_config.pi05: keys = [ "action_in_proj", "action_out_proj", "time_mlp_in", "time_mlp_out", ] else: keys = [ "state_proj", "action_in_proj", "action_out_proj", "action_time_mlp_in", "action_time_mlp_out", ] projection_params = {} for key in keys: kernel_params = initial_params["projection_params"][key]["kernel"] bias_params = initial_params["projection_params"][key]["bias"] if isinstance(kernel_params, dict): weight = kernel_params["value"] bias = bias_params["value"] else: weight = kernel_params bias = bias_params pytorch_weight_key = f"{key}.weight" pytorch_bias_key = f"{key}.bias" projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias)) # Create configs based on checkpoint path # All models use the same PaliGemma config structure class PaliGemmaConfig: def __init__(self): self.vision_config = type( "obj", (object,), { "hidden_size": 1152, "num_hidden_layers": 27, "num_attention_heads": 16, "intermediate_size": 4304, "patch_size": 14, "projection_dim": 2048, }, )() self.text_config = type( "obj", (object,), { "hidden_size": 2048, "num_hidden_layers": 18, "num_attention_heads": 8, "head_dim": 256, "intermediate_size": 16384, }, )() paligemma_config = PaliGemmaConfig() action_expert_config = openpi.models.gemma.get_config("gemma_300m") # Process PaliGemma weights paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config) 
# Process Gemma weights from expert_params gemma_params = slice_gemma_state_dict( expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05 ) # Instantiate model pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config) # Combine all parameters (no prefix needed for our model structure) all_params = {**paligemma_params, **gemma_params, **projection_params} # Load state dict pi0_model.load_state_dict(all_params, strict=False) if precision == "float32": pi0_model = pi0_model.to(torch.float32) elif precision == "bfloat16": pi0_model = pi0_model.to(torch.bfloat16) else: raise ValueError(f"Invalid precision: {precision}") # Save the converted model using safetensors os.makedirs(output_path, exist_ok=True) # Save model weights as SafeTensors using save_model to handle tied weights safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors")) # Copy assets folder if it exists assets_source = pathlib.Path(checkpoint_dir).parent / "assets" if assets_source.exists(): assets_dest = pathlib.Path(output_path) / "assets" if assets_dest.exists(): shutil.rmtree(assets_dest) shutil.copytree(assets_source, assets_dest) # Save config as JSON for reference config_dict = { "action_dim": model_config.action_dim, "action_horizon": model_config.action_horizon, "paligemma_variant": model_config.paligemma_variant, "action_expert_variant": model_config.action_expert_variant, "precision": precision, } with open(os.path.join(output_path, "config.json"), "w") as f: json.dump(config_dict, f, indent=2) print("Model conversion completed successfully!") print(f"Model saved to {output_path}") def main( checkpoint_dir: str, config_name: str, output_path: str | None = None, precision: Literal["float32", "bfloat16", "float16"] = "bfloat16", *, inspect_only: bool = False, ): """Load JAX model and optionally convert to PyTorch. 
Args: checkpoint_dir: Path to the JAX checkpoint directory output_path: Path to save converted PyTorch model (required for conversion) precision: Precision for model conversion inspect_only: Only inspect parameter keys, don't convert """ model_config = _config.get_config(config_name).model if not isinstance(model_config, openpi.models.pi0_config.Pi0Config): raise ValueError(f"Config {config_name} is not a Pi0Config") if inspect_only: load_jax_model_and_print_keys(checkpoint_dir) else: if not output_path: print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.") return convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config) if __name__ == "__main__": tyro.cli(main) ================================================ FILE: examples/droid/README.md ================================================ # DROID Policies in openpi We offer instructions for: - [Running inference for our best $pi_{0.5}$-DROID policy](./README.md#running-droid-inference) - [Running inference for other pre-trained DROID policies ($\pi_0$, $\pi_0$-FAST, ...)](./README.md#running-roboarena-baseline-policies) - [Pre-training *generalist* policies on the *full* DROID dataset](./README_train.md#training-on-droid) - [Fine-tuning expert $\pi_{0.5}$ on your custom DROID dataset](./README_train.md#fine-tuning-on-custom-droid-datasets) ## Running DROID Inference This example shows how to run the fine-tuned $\pi_{0.5}$-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). Based on the [public RoboArena benchmark](https://robo-arena.github.io/leaderboard), this is currently our strongest generalist DROID policy. ### Step 1: Start a policy server Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference. 1. 
On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi). 2. Start the OpenPI server via the following command: ```bash uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=gs://openpi-assets/checkpoints/pi05_droid ``` You can also run the equivalent command below: ```bash uv run scripts/serve_policy.py --env=DROID ``` ### Step 2: Run the DROID robot 1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC. 2. On the control laptop, activate your DROID conda environment. 3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`. 4. Install `tyro`, which we will use for command line parsing: `pip install tyro`. 5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory. 6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with). 7. Run the `main.py` file. Make sure to point the remote host and port at the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also make sure to specify the external camera to use for the policy (we only input one external camera); choose from ["left", "right"].
```bash python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left" ``` The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting! ## Troubleshooting | Issue | Solution | |-------|----------| | Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. | | Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explorer` in the command line. | | Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). | | Policy does not perform the task well | In our experiments, the policy could perform simple table top manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera, so make sure you are feeding the desired camera to the policy). Use `ZED_Explorer` to check that the camera view you are passing to the policy can see all relevant objects in the scene. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort.
:) | ## Running Other Policies We provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies. Then follow the instructions above to run evaluation on the DROID robot. ``` # Train from pi0-FAST, using FAST tokenizer uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid # Train from pi0, using flow matching uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_droid # Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer. uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid # Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer). uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid # Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset). uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid # Trained from PaliGemma, using FSQ tokenizer. uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid # pi0-style diffusion / flow VLA, trained on DROID from PaliGemma. uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid ``` You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py). 
================================================ FILE: examples/droid/README_train.md ================================================ # Training on DROID Here we describe how to fine-tune the pi0.5 model on the *full* DROID dataset. This is an approximate open-source reproduction of the pi05-DROID training pipeline (with small differences in data loading and in the action space used). For a tutorial on how to fine-tune your model with a smaller, custom dataset collected on the DROID platform, see below. In contrast to the rest of openpi, which uses LeRobot for data loading, we need to use RLDS as the data format for full DROID training (since at the moment LeRobot isn't scalable enough for larger datasets like DROID -- they are working on improving it though). Below, we provide instructions for updating your openpi environment for RLDS data loading and where to download the DROID dataset. ## Install We need a few additional dependencies for RLDS data loading. Run: ```bash uv sync --group rlds ``` ## Download DROID dataset You can download the DROID dataset with the following command (after installing the `gsutil` Google Cloud CLI): ``` gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 <your_download_path>/droid/1.0.1 ``` Note that downloading version 1.0.1 is important (not v1.0.0): it contains the complete set of language annotations (~75k episodes) while v1.0.0 only has annotations for 30k episodes. If for some reason you would like to use another version, modify the line `version="1.0.1"` in the `DroidRldsDataset` object [here](src/openpi/training/droid_rlds_dataset.py). You will need 1.8TB of disk storage to download the DROID RLDS dataset. ## Run First, change the `rlds_data_dir` path in your `TrainConfig` to the directory that you downloaded the `droid` dataset into (see [src/openpi/training/config.py](src/openpi/training/config.py)).
Then, compute normalization statistics (this will take ~10 minutes): ```bash uv run --group rlds scripts/compute_norm_stats.py --config-name pi05_full_droid_finetune --max-frames 10_000_000 ``` Run training: ```bash XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py pi05_full_droid_finetune --exp-name=my_experiment --overwrite ``` **Note**: The original pi0.5-DROID model was trained with joint velocity actions. Joint velocity actions are not compatible with simulated evaluation environments (much harder to simulate). Thus, we do not recommend training with joint velocity actions and instead use joint position actions here. ## Compute Requirements Our DROID training config requires approximately 2 days on 8x H100 GPUs for convergence (100k iterations, bs256, approx. 1 epoch). If you start from PaliGemma instead of pi0 initialization, plan with ~5 days on 8x H100s (240k iterations, i.e. 3 epochs). We have experimented with LoRA for cheaper finetuning, but haven't found the policies to perform well so far. ## Data Filtering Like any diverse real-robot dataset, the DROID dataset isn't perfectly "clean" and we have found data filtering to significantly improve policy performance. Concretely, the DROID dataset contains many *idle* timesteps in which the robot does not move (in part due to the VR teleoperation interface that was used during data collection, we will not go into too much detail here). Appropriate filtering of these idle transitions can improve policy performance. By default, our openpi training recipe implements the same idle filter used to train all pi-DROID models. We implement it by pre-computing which dataset indices to sample during training. You can check [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) for how we compute these indices. Roughly speaking, we filter any time steps for which the next chunk of actions would be largely idle. 
During training, our code automatically pulls our pre-computed list of indices from cloud storage and applies them. If you want to modify the idle filter / create your custom sampling logic, you can modify our script to generate a new index list and provide it via the `filter_dict_path="<path_to_filter_dict>"` argument in [src/openpi/training/config.py](src/openpi/training/config.py). **Note**: our list of filtering indices is only valid for the `droid/1.0.1` dataset mentioned in the download section above, and will not provide valid filtering for any other version of the DROID dataset, so make sure you download the dataset above! If you have a custom DROID version, you can rerun the [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) script to generate a new list of sampling indices. ## RoboArena Consider submitting your DROID policies to the [RoboArena benchmark](https://robo-arena.github.io/), which allows you to evaluate your policies on diverse tasks & scenes, **in the real world**! :) If you have questions about RoboArena, please email [karl.pertsch@gmail.com](mailto:karl.pertsch@gmail.com). # Fine-Tuning on Custom DROID Datasets Here we describe how to fine-tune a model on a custom (smaller) dataset collected on the DROID platform. As with other datasets, we will first convert the custom DROID dataset to LeRobot and then fine-tune a model (pi05-droid) on it. Note: We use LeRobot here, since we assume the custom DROID fine-tuning dataset to be relatively small (<10s of hours). For larger datasets (like the full DROID dataset) we recommend using RLDS for its better efficiency (see the example above). ## Step 1: Converting your custom DROID dataset to LeRobot We will use a small subset of the real DROID dataset for this example.
This is a subset of just 30 demonstrations -- we assume that you will use your own dataset instead, but here is the command to download our subset (1.6GB): ``` gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04 <your_target_path> ``` We will also download the language annotations for the DROID dataset so we can pair our demonstrations with language instructions. Again, for your own data you can manually enter your language instructions and don't need to download our annotations. To download the DROID language annotations (12MB), run: ``` gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json <your_target_path> ``` For your own dataset, make sure that each episode's directory contains a folder called `recordings/MP4` -- if not, you need to first run the MP4 video extraction (from SVO files) using the script [here](https://github.com/droid-dataset/droid/blob/main/scripts/convert/svo_to_mp4.py). Now, we will use the `convert_droid_data_to_lerobot.py` script to create a LeRobot version of this dataset (takes <5min for the 30 demonstrations): ``` uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir <your_target_path> ``` ## Step 2: Run fine-tuning with your custom dataset Now we can run fine-tuning with our converted custom dataset. We provide an example config for fine-tuning `pi05_droid` on the custom dataset we created. You can modify the config easily to work with other base models, or use your custom DROID dataset in `config.py` (search for `pi05_droid_finetune`). To launch training: ``` uv run scripts/train.py pi05_droid_finetune --exp-name=my_experiment --overwrite ``` Once trained, you can follow the instructions in [`examples/droid/README.md`](examples/droid/README.md) to serve the policy and run it on the robot.
"""
Iterates through the DROID dataset and creates a json mapping from episode unique IDs to ranges of time steps
that should be sampled during training (all others are filtered out).

Filtering logic:
We look for ranges of consecutive steps that contain fewer than min_idle_len consecutive idle frames
(default to 7 -- as most DROID action-chunking policies run the first 8 actions generated in each chunk, filtering
this way means the policy will not get stuck outputting stationary actions). Additionally, we also only keep non-idle
ranges of length at least min_non_idle_len (default to 16 frames = ~1 second), while also removing the last
filter_last_n_in_ranges frames from the end of each range (as those all correspond to action chunks with many idle actions).

This leaves us with trajectory segments consisting of contiguous, significant movement. Training on this filtered set
yields policies that output fewer stationary actions (i.e., get "stuck" in states less).
"""

import json
import os
from pathlib import Path

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = ""  # Set to the GPU you want to use, or leave empty for CPU

builder = tfds.builder_from_directory(
    # path to the `droid` directory (not its parent)
    builder_dir="",
)
ds = builder.as_dataset(split="train", shuffle_files=False)
# `ignore_errors()` only builds a transformation; it has to be applied to the dataset (and the
# result kept) for erroring elements to actually be skipped.
ds = ds.apply(tf.data.experimental.ignore_errors())

keep_ranges_path = ""

min_idle_len = 7  # Idle runs of at least this many consecutive frames are filtered out entirely
min_non_idle_len = 16  # If fewer than this number of consecutive non-idle frames, filter all of them out
filter_last_n_in_ranges = 10  # When using a filter dict, remove this many frames from the end of each range

# Resume from a previous (partial) run if an output file already exists.
keep_ranges_map = {}
if Path(keep_ranges_path).exists():
    with Path(keep_ranges_path).open("r") as f:
        keep_ranges_map = json.load(f)
    print(f"Resuming from {len(keep_ranges_map)} episodes already processed")

for ep_idx, ep in enumerate(tqdm(ds)):
    recording_folderpath = ep["episode_metadata"]["recording_folderpath"].numpy().decode()
    file_path = ep["episode_metadata"]["file_path"].numpy().decode()

    key = f"{recording_folderpath}--{file_path}"
    if key in keep_ranges_map:
        continue

    joint_velocities = [step["action_dict"]["joint_velocity"].numpy() for step in ep["steps"]]
    joint_velocities = np.array(joint_velocities)

    # A step is idle when its joint-velocity action differs from the previous step's by less
    # than 1e-3 in every dimension. The first step has no predecessor and is never idle.
    is_idle_array = np.hstack(
        [np.array([False]), np.all(np.abs(joint_velocities[1:] - joint_velocities[:-1]) < 1e-3, axis=1)]
    )

    # Find where idle runs begin and end. Padding with False on both sides guarantees that a
    # run touching the end of the episode still produces an end transition.
    is_idle_padded = np.concatenate([[False], is_idle_array, [False]])

    is_idle_diff = np.diff(is_idle_padded.astype(int))
    is_idle_true_starts = np.where(is_idle_diff == 1)[0]  # +1 transitions --> going from non-idle to idle
    is_idle_true_ends = np.where(is_idle_diff == -1)[0]  # -1 transitions --> going from idle to non-idle (exclusive end)

    # Keep only the idle runs of length at least min_idle_len; shorter pauses are not filtered
    true_segment_masks = (is_idle_true_ends - is_idle_true_starts) >= min_idle_len
    is_idle_true_starts = is_idle_true_starts[true_segment_masks]
    is_idle_true_ends = is_idle_true_ends[true_segment_masks]

    keep_mask = np.ones(len(joint_velocities), dtype=bool)
    for start, end in zip(is_idle_true_starts, is_idle_true_ends, strict=True):
        keep_mask[start:end] = False

    # Same run-detection logic as above, but for keep_mask, allowing us to filter out
    # contiguous kept ranges of length < min_non_idle_len
    keep_padded = np.concatenate([[False], keep_mask, [False]])

    keep_diff = np.diff(keep_padded.astype(int))
    keep_true_starts = np.where(keep_diff == 1)[0]  # +1 transitions --> going from filter out to keep
    keep_true_ends = np.where(keep_diff == -1)[0]  # -1 transitions --> going from keep to filter out

    # Find which steps correspond to non-idle segments of length at least min_non_idle_len
    true_segment_masks = (keep_true_ends - keep_true_starts) >= min_non_idle_len
    keep_true_starts = keep_true_starts[true_segment_masks]
    keep_true_ends = keep_true_ends[true_segment_masks]

    # Add mapping from episode unique ID key to list of non-idle ranges to keep,
    # dropping the last filter_last_n_in_ranges frames of every range
    keep_ranges_map[key] = []
    for start, end in zip(keep_true_starts, keep_true_ends, strict=True):
        keep_ranges_map[key].append((int(start), int(end) - filter_last_n_in_ranges))

    # Periodically checkpoint the results so that a long run can be resumed
    if ep_idx % 1000 == 0:
        with Path(keep_ranges_path).open("w") as f:
            json.dump(keep_ranges_map, f)

print("Done!")
with Path(keep_ranges_path).open("w") as f:
    json.dump(keep_ranges_map, f)
REPO_NAME = "your_hf_username/my_droid_dataset"  # Name of the output dataset, also used for the Hugging Face Hub


def resize_image(image, size):
    """Resize an image array with bicubic resampling.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        size: Target size handed to ``PIL.Image.resize``. Callers in this file pass (320, 180)
            for frames stored with shape (180, 320, 3), i.e. (width, height).

    Returns:
        The resized image as a numpy array.
    """
    image = Image.fromarray(image)
    return np.array(image.resize(size, resample=Image.BICUBIC))


def main(data_dir: str, *, push_to_hub: bool = False):
    """Convert raw DROID episodes found under ``data_dir`` into a LeRobot dataset.

    Any existing dataset at ``HF_LEROBOT_HOME / REPO_NAME`` is deleted first.

    Args:
        data_dir: Directory that contains ``aggregated-annotations-030724.json`` and, at any
            depth, episode folders holding a ``trajectory.h5`` file.
        push_to_hub: If True, push the finished dataset to the Hugging Face Hub.

    Raises:
        FileNotFoundError: If an episode folder has no ``metadata_*.json`` file.
    """
    # Clean up any existing dataset in the output directory
    output_path = HF_LEROBOT_HOME / REPO_NAME
    if output_path.exists():
        shutil.rmtree(output_path)
    data_dir = Path(data_dir)

    # Create LeRobot dataset, define features to store
    # We will follow the DROID data naming conventions here.
    # LeRobot assumes that dtype of image data is `image`
    dataset = LeRobotDataset.create(
        repo_id=REPO_NAME,
        robot_type="panda",
        fps=15,  # DROID data is typically recorded at 15fps
        features={
            # We call this "left" since we will only use the left stereo camera (following DROID RLDS convention)
            "exterior_image_1_left": {
                "dtype": "image",
                "shape": (180, 320, 3),  # This is the resolution used in the DROID RLDS dataset
                "names": ["height", "width", "channel"],
            },
            "exterior_image_2_left": {
                "dtype": "image",
                "shape": (180, 320, 3),
                "names": ["height", "width", "channel"],
            },
            "wrist_image_left": {
                "dtype": "image",
                "shape": (180, 320, 3),
                "names": ["height", "width", "channel"],
            },
            "joint_position": {
                "dtype": "float32",
                "shape": (7,),
                "names": ["joint_position"],
            },
            "gripper_position": {
                "dtype": "float32",
                "shape": (1,),
                "names": ["gripper_position"],
            },
            "actions": {
                "dtype": "float32",
                "shape": (8,),  # We will use joint *velocity* actions here (7D) + gripper position (1D)
                "names": ["actions"],
            },
        },
        image_writer_threads=10,
        image_writer_processes=5,
    )

    # Load language annotations
    # Note: we load the DROID language annotations for this example, but you can manually define them for your own data
    with (data_dir / "aggregated-annotations-030724.json").open() as f:
        language_annotations = json.load(f)

    # Loop over raw DROID fine-tuning datasets and write episodes to the LeRobot dataset
    # We assume the following directory structure:
    # RAW_DROID_PATH/
    #   - <...>/
    #     - recordings/
    #        - MP4/
    #          - <camera_id>.mp4  # single-view video of left stereo pair camera
    #     - trajectory.h5
    #   - <...>/
    episode_paths = list(data_dir.glob("**/trajectory.h5"))
    print(f"Found {len(episode_paths)} episodes for conversion")

    # We will loop over each dataset_name and write episodes to the LeRobot dataset
    for episode_path in tqdm(episode_paths, desc="Converting episodes"):
        # Load raw data
        recording_folderpath = episode_path.parent / "recordings" / "MP4"
        trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))

        # To load the language instruction, we need to parse out the episode_id from the metadata file
        # Again, you can modify this step for your own data, to load your own language instructions
        metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json")), None)
        if metadata_filepath is None:
            # Fail with a message that names the episode instead of a bare StopIteration.
            raise FileNotFoundError(f"No metadata_*.json file found in {episode_path.parent}")
        episode_id = metadata_filepath.name.split(".")[0].split("_")[-1]
        # Episodes without an annotation fall back to a generic instruction.
        language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[
            "language_instruction1"
        ]
        print(f"Converting episode with language instruction: {language_instruction}")

        # Write to LeRobot dataset
        for step in trajectory:
            observation = step["observation"]
            camera_type_dict = observation["camera_type"]
            # Camera type 0 is treated as the wrist camera; every other type as an exterior camera.
            wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]
            exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]
            dataset.add_frame(
                {
                    # Note: need to flip BGR --> RGB for loaded images
                    "exterior_image_1_left": resize_image(observation["image"][exterior_ids[0]][..., ::-1], (320, 180)),
                    "exterior_image_2_left": resize_image(observation["image"][exterior_ids[1]][..., ::-1], (320, 180)),
                    "wrist_image_left": resize_image(observation["image"][wrist_ids[0]][..., ::-1], (320, 180)),
                    "joint_position": np.asarray(observation["robot_state"]["joint_positions"], dtype=np.float32),
                    "gripper_position": np.asarray(
                        observation["robot_state"]["gripper_position"][None], dtype=np.float32
                    ),
                    # Important: we use joint velocity actions here since pi05-droid was pre-trained on joint velocity actions
                    "actions": np.concatenate(
                        [step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32
                    ),
                    "task": language_instruction,
                }
            )
        dataset.save_episode()

    # Optionally push to the Hugging Face Hub
    if push_to_hub:
        dataset.push_to_hub(
            # Tag the upload as DROID data (the previous tags were carried over from the LIBERO example).
            tags=["droid", "panda"],
            private=False,
            push_videos=True,
            license="apache-2.0",
        )
################ The rest of this file are functions to parse the raw DROID data ######################### ################ You don't need to worry about understanding this part ######################### ################ It was copied from here: https://github.com/JonathanYang0127/r2d2_rlds_dataset_builder/blob/parallel_convert/r2_d2/r2_d2.py ########################################################################################################## camera_type_dict = { "hand_camera_id": 0, "varied_camera_1_id": 1, "varied_camera_2_id": 1, } camera_type_to_string_dict = { 0: "hand_camera", 1: "varied_camera", 2: "fixed_camera", } def get_camera_type(cam_id): if cam_id not in camera_type_dict: return None type_int = camera_type_dict[cam_id] return camera_type_to_string_dict[type_int] class MP4Reader: def __init__(self, filepath, serial_number): # Save Parameters # self.serial_number = serial_number self._index = 0 # Open Video Reader # self._mp4_reader = cv2.VideoCapture(filepath) if not self._mp4_reader.isOpened(): raise RuntimeError("Corrupted MP4 File") def set_reading_parameters( self, image=True, # noqa: FBT002 concatenate_images=False, # noqa: FBT002 resolution=(0, 0), resize_func=None, ): # Save Parameters # self.image = image self.concatenate_images = concatenate_images self.resolution = resolution self.resize_func = cv2.resize self.skip_reading = not image if self.skip_reading: return def get_frame_resolution(self): width = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH) height = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT) return (width, height) def get_frame_count(self): if self.skip_reading: return 0 return int(self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_COUNT)) def set_frame_index(self, index): if self.skip_reading: return if index < self._index: self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1) self._index = index while self._index < index: self.read_camera(ignore_data=True) def _process_frame(self, frame): frame = 
copy.deepcopy(frame) if self.resolution == (0, 0): return frame return self.resize_func(frame, self.resolution) def read_camera(self, ignore_data=False, correct_timestamp=None): # noqa: FBT002 # Skip if Read Unnecessary # if self.skip_reading: return {} # Read Camera # success, frame = self._mp4_reader.read() self._index += 1 if not success: return None if ignore_data: return None # Return Data # data_dict = {} if self.concatenate_images or "stereo" not in self.serial_number: data_dict["image"] = {self.serial_number: self._process_frame(frame)} else: single_width = frame.shape[1] // 2 data_dict["image"] = { self.serial_number + "_left": self._process_frame(frame[:, :single_width, :]), self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]), } return data_dict def disable_camera(self): if hasattr(self, "_mp4_reader"): self._mp4_reader.release() class RecordedMultiCameraWrapper: def __init__(self, recording_folderpath, camera_kwargs={}): # noqa: B006 # Save Camera Info # self.camera_kwargs = camera_kwargs # Open Camera Readers # mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4") all_filepaths = mp4_filepaths self.camera_dict = {} for f in all_filepaths: serial_number = f.split("/")[-1][:-4] cam_type = get_camera_type(serial_number) camera_kwargs.get(cam_type, {}) if f.endswith(".mp4"): Reader = MP4Reader # noqa: N806 else: raise ValueError self.camera_dict[serial_number] = Reader(f, serial_number) def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}): # noqa: B006 full_obs_dict = defaultdict(dict) # Read Cameras In Randomized Order # all_cam_ids = list(self.camera_dict.keys()) # random.shuffle(all_cam_ids) for cam_id in all_cam_ids: if "stereo" in cam_id: continue try: cam_type = camera_type_dict[cam_id] except KeyError: print(f"{self.camera_dict} -- {camera_type_dict}") raise ValueError(f"Camera type {cam_id} not found in camera_type_dict") # noqa: B904 curr_cam_kwargs = self.camera_kwargs.get(cam_type, {}) 
self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs) timestamp = timestamp_dict.get(cam_id + "_frame_received", None) if index is not None: self.camera_dict[cam_id].set_frame_index(index) data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp) # Process Returned Data # if data_dict is None: return None for key in data_dict: full_obs_dict[key].update(data_dict[key]) return full_obs_dict def get_hdf5_length(hdf5_file, keys_to_ignore=[]): # noqa: B006 length = None for key in hdf5_file: if key in keys_to_ignore: continue curr_data = hdf5_file[key] if isinstance(curr_data, h5py.Group): curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore) elif isinstance(curr_data, h5py.Dataset): curr_length = len(curr_data) else: raise ValueError if length is None: length = curr_length assert curr_length == length return length def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]): # noqa: B006 data_dict = {} for key in hdf5_file: if key in keys_to_ignore: continue curr_data = hdf5_file[key] if isinstance(curr_data, h5py.Group): data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore) elif isinstance(curr_data, h5py.Dataset): data_dict[key] = curr_data[index] else: raise ValueError return data_dict class TrajectoryReader: def __init__(self, filepath, read_images=True): # noqa: FBT002 self._hdf5_file = h5py.File(filepath, "r") is_video_folder = "observations/videos" in self._hdf5_file self._read_images = read_images and is_video_folder self._length = get_hdf5_length(self._hdf5_file) self._video_readers = {} self._index = 0 def length(self): return self._length def read_timestep(self, index=None, keys_to_ignore=[]): # noqa: B006 # Make Sure We Read Within Range # if index is None: index = self._index else: assert not self._read_images self._index = index assert index < self._length # Load Low Dimensional Data # keys_to_ignore = [*keys_to_ignore.copy(), "videos"] timestep = 
load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore) # Increment Read Index # self._index += 1 # Return Timestep # return timestep def close(self): self._hdf5_file.close() def load_trajectory( filepath=None, read_cameras=True, # noqa: FBT002 recording_folderpath=None, camera_kwargs={}, # noqa: B006 remove_skipped_steps=False, # noqa: FBT002 num_samples_per_traj=None, num_samples_per_traj_coeff=1.5, ): read_recording_folderpath = read_cameras and (recording_folderpath is not None) traj_reader = TrajectoryReader(filepath) if read_recording_folderpath: camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs) horizon = traj_reader.length() timestep_list = [] # Choose Timesteps To Save # if num_samples_per_traj: num_to_save = num_samples_per_traj if remove_skipped_steps: num_to_save = int(num_to_save * num_samples_per_traj_coeff) max_size = min(num_to_save, horizon) indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False)) else: indices_to_save = np.arange(horizon) # Iterate Over Trajectory # for i in indices_to_save: # Get HDF5 Data # timestep = traj_reader.read_timestep(index=i) # If Applicable, Get Recorded Data # if read_recording_folderpath: timestamp_dict = timestep["observation"]["timestamp"]["cameras"] camera_type_dict = { k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items() } camera_obs = camera_reader.read_cameras( index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict ) camera_failed = camera_obs is None # Add Data To Timestep If Successful # if camera_failed: break timestep["observation"].update(camera_obs) # Filter Steps # step_skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True) delete_skipped_step = step_skipped and remove_skipped_steps # Save Filtered Timesteps # if delete_skipped_step: del timestep else: timestep_list.append(timestep) # Remove Extra Transitions # timestep_list = 
np.array(timestep_list) if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj): ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False) timestep_list = timestep_list[ind_to_keep] # Close Readers # traj_reader.close() # Return Data # return timestep_list if __name__ == "__main__": tyro.cli(main) ================================================ FILE: examples/droid/main.py ================================================ # ruff: noqa import contextlib import dataclasses import datetime import faulthandler import os import signal import time from moviepy.editor import ImageSequenceClip import numpy as np from openpi_client import image_tools from openpi_client import websocket_client_policy import pandas as pd from PIL import Image from droid.robot_env import RobotEnv import tqdm import tyro faulthandler.enable() # DROID data collection frequency -- we slow down execution to match this frequency DROID_CONTROL_FREQUENCY = 15 @dataclasses.dataclass class Args: # Hardware parameters left_camera_id: str = "" # e.g., "24259877" right_camera_id: str = "" # e.g., "24514023" wrist_camera_id: str = "" # e.g., "13062452" # Policy parameters external_camera: str | None = ( None # which external camera should be fed to the policy, choose from ["left", "right"] ) # Rollout parameters max_timesteps: int = 600 # How many actions to execute from a predicted action chunk before querying policy server again # 8 is usually a good default (equals 0.5 seconds of action execution). 
open_loop_horizon: int = 8 # Remote server parameters remote_host: str = "0.0.0.0" # point this to the IP address of the policy server, e.g., "192.168.1.100" remote_port: int = ( 8000 # point this to the port of the policy server, default server port for openpi servers is 8000 ) # We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is # waiting for a new action chunk, it will raise an exception and the server connection dies. # This context manager temporarily prevents Ctrl+C and delays it after the server call is complete. @contextlib.contextmanager def prevent_keyboard_interrupt(): """Temporarily prevent keyboard interrupts by delaying them until after the protected code.""" interrupted = False original_handler = signal.getsignal(signal.SIGINT) def handler(signum, frame): nonlocal interrupted interrupted = True signal.signal(signal.SIGINT, handler) try: yield finally: signal.signal(signal.SIGINT, original_handler) if interrupted: raise KeyboardInterrupt def main(args: Args): # Make sure external camera is specified by user -- we only use one external camera for the policy assert ( args.external_camera is not None and args.external_camera in ["left", "right"] ), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}" # Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important. 
env = RobotEnv(action_space="joint_velocity", gripper_action_space="position") print("Created the droid env!") # Connect to the policy server policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port) df = pd.DataFrame(columns=["success", "duration", "video_filename"]) while True: instruction = input("Enter instruction: ") # Rollout parameters actions_from_chunk_completed = 0 pred_action_chunk = None # Prepare to save video of rollout timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S") video = [] bar = tqdm.tqdm(range(args.max_timesteps)) print("Running rollout... press Ctrl+C to stop early.") for t_step in bar: start_time = time.time() try: # Get the current observation curr_obs = _extract_observation( args, env.get_observation(), # Save the first observation to disk save_to_disk=t_step == 0, ) video.append(curr_obs[f"{args.external_camera}_image"]) # Send websocket request to policy server if it's time to predict a new chunk if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon: actions_from_chunk_completed = 0 # We resize images on the robot laptop to minimize the amount of data sent to the policy server # and improve latency. 
request_data = { "observation/exterior_image_1_left": image_tools.resize_with_pad( curr_obs[f"{args.external_camera}_image"], 224, 224 ), "observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224), "observation/joint_position": curr_obs["joint_position"], "observation/gripper_position": curr_obs["gripper_position"], "prompt": instruction, } # Wrap the server call in a context manager to prevent Ctrl+C from interrupting it # Ctrl+C will be handled after the server call is complete with prevent_keyboard_interrupt(): # this returns action chunk [10, 8] of 10 joint velocity actions (7) + gripper position (1) pred_action_chunk = policy_client.infer(request_data)["actions"] assert pred_action_chunk.shape == (10, 8) # Select current action to execute from chunk action = pred_action_chunk[actions_from_chunk_completed] actions_from_chunk_completed += 1 # Binarize gripper action if action[-1].item() > 0.5: # action[-1] = 1.0 action = np.concatenate([action[:-1], np.ones((1,))]) else: # action[-1] = 0.0 action = np.concatenate([action[:-1], np.zeros((1,))]) # clip all dimensions of action to [-1, 1] action = np.clip(action, -1, 1) env.step(action) # Sleep to match DROID data collection frequency elapsed_time = time.time() - start_time if elapsed_time < 1 / DROID_CONTROL_FREQUENCY: time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time) except KeyboardInterrupt: break video = np.stack(video) save_filename = "video_" + timestamp ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264") success: str | float | None = None while not isinstance(success, float): success = input( "Did the rollout succeed? 
(enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec" ) if success == "y": success = 1.0 elif success == "n": success = 0.0 success = float(success) / 100 if not (0 <= success <= 1): print(f"Success must be a number in [0, 100] but got: {success * 100}") df = df.append( { "success": success, "duration": t_step, "video_filename": save_filename, }, ignore_index=True, ) if input("Do one more eval? (enter y or n) ").lower() != "y": break env.reset() os.makedirs("results", exist_ok=True) timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y") csv_filename = os.path.join("results", f"eval_{timestamp}.csv") df.to_csv(csv_filename) print(f"Results saved to {csv_filename}") def _extract_observation(args: Args, obs_dict, *, save_to_disk=False): image_observations = obs_dict["image"] left_image, right_image, wrist_image = None, None, None for key in image_observations: # Note the "left" below refers to the left camera in the stereo pair. # The model is only trained on left stereo cams, so we only feed those. 
if args.left_camera_id in key and "left" in key: left_image = image_observations[key] elif args.right_camera_id in key and "left" in key: right_image = image_observations[key] elif args.wrist_camera_id in key and "left" in key: wrist_image = image_observations[key] # Drop the alpha dimension left_image = left_image[..., :3] right_image = right_image[..., :3] wrist_image = wrist_image[..., :3] # Convert to RGB left_image = left_image[..., ::-1] right_image = right_image[..., ::-1] wrist_image = wrist_image[..., ::-1] # In addition to image observations, also capture the proprioceptive state robot_state = obs_dict["robot_state"] cartesian_position = np.array(robot_state["cartesian_position"]) joint_position = np.array(robot_state["joint_positions"]) gripper_position = np.array([robot_state["gripper_position"]]) # Save the images to disk so that they can be viewed live while the robot is running # Create one combined image to make live viewing easy if save_to_disk: combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1) combined_image = Image.fromarray(combined_image) combined_image.save("robot_camera_views.png") return { "left_image": left_image, "right_image": right_image, "wrist_image": wrist_image, "cartesian_position": cartesian_position, "joint_position": joint_position, "gripper_position": gripper_position, } if __name__ == "__main__": args: Args = tyro.cli(Args) main(args) ================================================ FILE: examples/inference.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import dataclasses\n", "\n", "import jax\n", "\n", "from openpi.models import model as _model\n", "from openpi.policies import droid_policy\n", "from openpi.policies import policy_config as _policy_config\n", "from openpi.shared import download\n", "from openpi.training import config as _config\n", "from openpi.training import data_loader as 
_data_loader" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Policy inference\n", "\n", "The following example shows how to create a policy from a checkpoint and run inference on a dummy example." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "config = _config.get_config(\"pi0_fast_droid\")\n", "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_fast_droid\")\n", "\n", "# Create a trained policy.\n", "policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n", "\n", "# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n", "example = droid_policy.make_droid_example()\n", "result = policy.infer(example)\n", "\n", "# Delete the policy to free up memory.\n", "del policy\n", "\n", "print(\"Actions shape:\", result[\"actions\"].shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Working with a live model\n", "\n", "\n", "The following example shows how to create a live model from a checkpoint and compute training loss. 
First, we are going to demonstrate how to do it with fake data.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "config = _config.get_config(\"pi0_aloha_sim\")\n", "\n", "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_aloha_sim\")\n", "key = jax.random.key(0)\n", "\n", "# Create a model from the checkpoint.\n", "model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n", "\n", "# We can create fake observations and actions to test the model.\n", "obs, act = config.model.fake_obs(), config.model.fake_act()\n", "\n", "# Sample actions from the model.\n", "loss = model.compute_loss(key, obs, act)\n", "print(\"Loss shape:\", loss.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we are going to create a data loader and use a real batch of training data to compute the loss." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Reduce the batch size to reduce memory usage.\n", "config = dataclasses.replace(config, batch_size=2)\n", "\n", "# Load a single batch of data. 
This is the same data that will be used during training.\n", "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n", "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n", "loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n", "obs, act = next(iter(loader))\n", "\n", "# Sample actions from the model.\n", "loss = model.compute_loss(key, obs, act)\n", "\n", "# Delete the model to free up memory.\n", "del model\n", "\n", "print(\"Loss shape:\", loss.shape)" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/libero/Dockerfile ================================================ # Dockerfile for the LIBERO benchmark. # Build the container: # docker build . -t libero -f examples/libero/Dockerfile # Run the container: # docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0 COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ RUN apt-get update && \ apt-get install -y \ make \ g++ \ clang \ libosmesa6-dev \ libgl1-mesa-glx \ libegl1 \ libglew-dev \ libglfw3-dev \ libgles2-mesa-dev \ libglib2.0-0 \ libsm6 \ libxrender1 \ libxext6 WORKDIR /app # Copy from the cache instead of linking since it's a mounted volume ENV UV_LINK_MODE=copy # Write the virtual environment outside of the project directory so it doesn't # leak out of the container when we mount the application code. 
ENV UV_PROJECT_ENVIRONMENT=/.venv # Copy the requirements files so we can install dependencies. # The rest of the project is mounted as a volume, so we don't need to rebuild on changes. # This strategy is best for development-style usage. COPY ./examples/libero/requirements.txt /tmp/requirements.txt COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml # Install python dependencies. RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero # Create a default config file to avoid an input prompt from LIBERO's init script. # https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py ENV LIBERO_CONFIG_PATH=/tmp/libero RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml benchmark_root: /app/third_party/libero/libero/libero bddl_files: /app/third_party/libero/libero/libero/bddl_files init_states: /app/third_party/libero/libero/libero/init_files datasets: /app/third_party/libero/libero/datasets assets: /app/third_party/libero/libero/libero/assets EOF RUN mkdir -p /usr/share/glvnd/egl_vendor.d && echo '{"file_format_version" : "1.0.0", "ICD" : { "library_path" : "libEGL_nvidia.so.0" }}' > /usr/share/glvnd/egl_vendor.d/10_nvidia.json CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py $CLIENT_ARGS"] ================================================ FILE: examples/libero/README.md ================================================ # LIBERO Benchmark This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO Note: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url 
https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command. This example requires git submodules to be initialized. Don't forget to run: ```bash git submodule update --init --recursive ``` ## With Docker (recommended) ```bash # Grant access to the X11 server: sudo xhost +local:docker # To run with the default checkpoint and task suite: SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build # To run with glx for Mujoco instead (use this if you have egl errors): MUJOCO_GL=glx SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build ``` You can customize the loaded checkpoint by providing additional `SERVER_ARGS` (see `scripts/serve_policy.py`), and the LIBERO task suite by providing additional `CLIENT_ARGS` (see `examples/libero/main.py`). For example: ```bash # To load a custom checkpoint (located in the top-level openpi/ directory): export SERVER_ARGS="--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir ./my_custom_checkpoint" # To run the libero_10 task suite: export CLIENT_ARGS="--args.task-suite-name libero_10" ``` ## Without Docker (not recommended) Terminal window 1: ```bash # Create virtual environment uv venv --python 3.8 examples/libero/.venv source examples/libero/.venv/bin/activate uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match uv pip install -e packages/openpi-client uv pip install -e third_party/libero export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero # Run the simulation python examples/libero/main.py # To run with glx for Mujoco instead (use this if you have egl errors): MUJOCO_GL=glx python examples/libero/main.py ``` Terminal window 2: ```bash # Run the server uv run scripts/serve_policy.py --env LIBERO ``` ## Results If you want to reproduce the following numbers, you can evaluate the checkpoint at 
`gs://openpi-assets/checkpoints/pi05_libero/`. This checkpoint was trained in openpi with the `pi05_libero` config. | Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average | |-------|---------------|---------------|-------------|-----------|---------| | π0.5 @ 30k (finetuned) | 98.8 | 98.2 | 98.0 | 92.4 | 96.85 ================================================ FILE: examples/libero/compose.yml ================================================ # Run with: # docker compose -f examples/libero/compose.yml up --build services: runtime: image: libero depends_on: - openpi_server build: context: ../.. dockerfile: examples/libero/Dockerfile init: true tty: true network_mode: host privileged: true volumes: - $PWD:/app - ../../data:/data - /tmp/.X11-unix:/tmp/.X11-unix:ro environment: - CLIENT_ARGS - DISPLAY=$DISPLAY - MUJOCO_GL=${MUJOCO_GL:-egl} - MUJOCO_EGL_DEVICE_ID=0 - NVIDIA_DRIVER_CAPABILITIES=all - PYOPENGL_PLATFORM=egl deploy: resources: reservations: devices: - driver: nvidia count: 1 capabilities: [gpu] openpi_server: image: openpi_server build: context: ../.. dockerfile: scripts/docker/serve_policy.Dockerfile init: true tty: true network_mode: host volumes: - $PWD:/app - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets environment: - SERVER_ARGS - OPENPI_DATA_HOME=/openpi_assets - IS_DOCKER=true # Comment out this block if not running on a machine with GPUs. deploy: resources: reservations: devices: - driver: nvidia count: 1 capabilities: [gpu] ================================================ FILE: examples/libero/convert_libero_data_to_lerobot.py ================================================ """ Minimal example script for converting a dataset to LeRobot format. We use the Libero dataset (stored in RLDS) for this example, but it can be easily modified for any other data you have saved in a custom format. 
Usage: uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data If you want to push your dataset to the Hugging Face Hub, you can use the following command: uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub Note: to run the script, you need to install tensorflow_datasets: `uv pip install tensorflow tensorflow_datasets` You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds The resulting dataset will get saved to the $HF_LEROBOT_HOME directory. Running this conversion script will take approximately 30 minutes. """ import shutil from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME from lerobot.common.datasets.lerobot_dataset import LeRobotDataset import tensorflow_datasets as tfds import tyro REPO_NAME = "your_hf_username/libero" # Name of the output dataset, also used for the Hugging Face Hub RAW_DATASET_NAMES = [ "libero_10_no_noops", "libero_goal_no_noops", "libero_object_no_noops", "libero_spatial_no_noops", ] # For simplicity we will combine multiple Libero datasets into one training dataset def main(data_dir: str, *, push_to_hub: bool = False): # Clean up any existing dataset in the output directory output_path = HF_LEROBOT_HOME / REPO_NAME if output_path.exists(): shutil.rmtree(output_path) # Create LeRobot dataset, define features to store # OpenPi assumes that proprio is stored in `state` and actions in `action` # LeRobot assumes that dtype of image data is `image` dataset = LeRobotDataset.create( repo_id=REPO_NAME, robot_type="panda", fps=10, features={ "image": { "dtype": "image", "shape": (256, 256, 3), "names": ["height", "width", "channel"], }, "wrist_image": { "dtype": "image", "shape": (256, 256, 3), "names": ["height", "width", "channel"], }, "state": { "dtype": "float32", "shape": (8,), "names": ["state"], }, "actions": { "dtype": "float32", "shape": (7,), "names": ["actions"], }, }, 
image_writer_threads=10, image_writer_processes=5, ) # Loop over raw Libero datasets and write episodes to the LeRobot dataset # You can modify this for your own data format for raw_dataset_name in RAW_DATASET_NAMES: raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train") for episode in raw_dataset: for step in episode["steps"].as_numpy_iterator(): dataset.add_frame( { "image": step["observation"]["image"], "wrist_image": step["observation"]["wrist_image"], "state": step["observation"]["state"], "actions": step["action"], "task": step["language_instruction"].decode(), } ) dataset.save_episode() # Optionally push to the Hugging Face Hub if push_to_hub: dataset.push_to_hub( tags=["libero", "panda", "rlds"], private=False, push_videos=True, license="apache-2.0", ) if __name__ == "__main__": tyro.cli(main) ================================================ FILE: examples/libero/main.py ================================================ import collections import dataclasses import logging import math import pathlib import imageio from libero.libero import benchmark from libero.libero import get_libero_path from libero.libero.envs import OffScreenRenderEnv import numpy as np from openpi_client import image_tools from openpi_client import websocket_client_policy as _websocket_client_policy import tqdm import tyro LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0] LIBERO_ENV_RESOLUTION = 256 # resolution used to render training data @dataclasses.dataclass class Args: ################################################################################################################# # Model server parameters ################################################################################################################# host: str = "0.0.0.0" port: int = 8000 resize_size: int = 224 replan_steps: int = 5 ################################################################################################################# # LIBERO environment-specific parameters 
def eval_libero(args: Args) -> None:
    """Roll out a remotely served policy on every task of one LIBERO suite.

    For each task, `args.num_trials_per_task` episodes are run from the suite's
    default initial states. Observations are sent to the policy server at
    `args.host`:`args.port`; the first `args.replan_steps` actions of every
    returned chunk are executed before the policy is queried again. A replay
    video is written to `args.video_out_path` for each episode and success
    counts are logged.

    Args:
        args: Evaluation settings (server address, task suite, rollout counts,
            output directory, seed).

    Raises:
        ValueError: If `args.task_suite_name` has no step budget defined below.
    """
    # Step budget per suite, set a little above the longest training demo.
    max_steps_by_suite = {
        "libero_spatial": 220,  # longest training demo has 193 steps
        "libero_object": 280,  # longest training demo has 254 steps
        "libero_goal": 300,  # longest training demo has 270 steps
        "libero_10": 520,  # longest training demo has 505 steps
        "libero_90": 400,  # longest training demo has 373 steps
    }

    # Set random seed
    np.random.seed(args.seed)

    # Initialize LIBERO task suite
    benchmark_dict = benchmark.get_benchmark_dict()
    task_suite = benchmark_dict[args.task_suite_name]()
    num_tasks_in_suite = task_suite.n_tasks
    logging.info(f"Task suite: {args.task_suite_name}")

    pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)

    if args.task_suite_name not in max_steps_by_suite:
        raise ValueError(f"Unknown task suite: {args.task_suite_name}")
    max_steps = max_steps_by_suite[args.task_suite_name]

    client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)

    # Start evaluation
    total_episodes, total_successes = 0, 0
    for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
        # Get task
        task = task_suite.get_task(task_id)

        # Get default LIBERO initial states
        initial_states = task_suite.get_task_init_states(task_id)

        # Initialize LIBERO environment and task description
        env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)

        # Start episodes
        task_episodes, task_successes = 0, 0
        for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
            logging.info(f"\nTask: {task_description}")

            # Reset environment
            env.reset()
            action_plan = collections.deque()

            # Set initial states
            obs = env.set_init_state(initial_states[episode_idx])

            # Setup
            t = 0
            replay_images = []
            # Reset per episode: if the very first env.step raises, `done` would otherwise be
            # unbound (first episode) or still hold the previous episode's outcome.
            done = False

            logging.info(f"Starting episode {task_episodes+1}...")
            while t < max_steps + args.num_steps_wait:
                try:
                    # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
                    # and we need to wait for them to fall
                    if t < args.num_steps_wait:
                        obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)
                        t += 1
                        continue

                    # Get preprocessed image
                    # IMPORTANT: rotate 180 degrees to match train preprocessing
                    img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
                    wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
                    img = image_tools.convert_to_uint8(
                        image_tools.resize_with_pad(img, args.resize_size, args.resize_size)
                    )
                    wrist_img = image_tools.convert_to_uint8(
                        image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)
                    )

                    # Save preprocessed image for replay video
                    replay_images.append(img)

                    if not action_plan:
                        # Finished executing previous action chunk -- compute new chunk
                        # Prepare observations dict
                        element = {
                            "observation/image": img,
                            "observation/wrist_image": wrist_img,
                            "observation/state": np.concatenate(
                                (
                                    obs["robot0_eef_pos"],
                                    _quat2axisangle(obs["robot0_eef_quat"]),
                                    obs["robot0_gripper_qpos"],
                                )
                            ),
                            "prompt": str(task_description),
                        }

                        # Query model to get action
                        action_chunk = client.infer(element)["actions"]
                        assert (
                            len(action_chunk) >= args.replan_steps
                        ), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
                        action_plan.extend(action_chunk[: args.replan_steps])

                    action = action_plan.popleft()

                    # Execute action in environment
                    obs, reward, done, info = env.step(action.tolist())
                    if done:
                        task_successes += 1
                        total_successes += 1
                        break
                    t += 1

                except Exception as e:
                    # Best effort: a failing episode is counted as a failure and evaluation continues.
                    # logging.exception keeps the traceback so the cause is not lost.
                    logging.exception(f"Caught exception: {e}")
                    break

            task_episodes += 1
            total_episodes += 1

            # Save a replay video of the episode
            suffix = "success" if done else "failure"
            task_segment = task_description.replace(" ", "_")
            # Nothing to encode when the episode failed before the first policy step.
            if replay_images:
                imageio.mimwrite(
                    pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4",
                    [np.asarray(x) for x in replay_images],
                    fps=10,
                )

            # Log current results
            logging.info(f"Success: {done}")
            logging.info(f"# episodes completed so far: {total_episodes}")
            logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")

        # Log final results (guard against a zero-trial configuration)
        logging.info(f"Current task success rate: {float(task_successes) / float(max(task_episodes, 1))}")
        logging.info(f"Current total success rate: {float(total_successes) / float(max(total_episodes, 1))}")

    logging.info(f"Total success rate: {float(total_successes) / float(max(total_episodes, 1))}")
    logging.info(f"Total episodes: {total_episodes}")
np.sqrt(1.0 - quat[3] * quat[3]) if math.isclose(den, 0.0): # This is (close to) a zero degree rotation, immediately return return np.zeros(3) return (quat[:3] * 2.0 * math.acos(quat[3])) / den if __name__ == "__main__": logging.basicConfig(level=logging.INFO) tyro.cli(eval_libero) ================================================ FILE: examples/libero/requirements.in ================================================ imageio[ffmpeg] numpy==1.22.4 tqdm tyro PyYaml opencv-python==4.6.0.66 torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 robosuite==1.4.1 matplotlib==3.5.3 ================================================ FILE: examples/libero/requirements.txt ================================================ # This file was autogenerated by uv via the following command: # uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match absl-py==2.1.0 # via mujoco certifi==2024.12.14 # via requests charset-normalizer==3.4.0 # via requests cycler==0.12.1 # via matplotlib docstring-parser==0.16 # via tyro etils==1.3.0 # via mujoco eval-type-backport==0.2.0 # via tyro evdev==1.7.1 # via pynput fonttools==4.55.3 # via matplotlib glfw==1.12.0 # via mujoco idna==3.10 # via requests imageio==2.35.1 # via -r examples/libero/requirements.in imageio-ffmpeg==0.5.1 # via imageio importlib-metadata==8.5.0 # via typeguard importlib-resources==6.4.5 # via etils kiwisolver==1.4.7 # via matplotlib llvmlite==0.36.0 # via numba markdown-it-py==3.0.0 # via rich matplotlib==3.5.3 # via -r examples/libero/requirements.in mdurl==0.1.2 # via markdown-it-py mujoco==3.2.3 # via robosuite numba==0.53.1 # via robosuite numpy==1.22.4 # via # -r examples/libero/requirements.in # imageio # matplotlib # mujoco # numba # opencv-python # robosuite # scipy # torchvision opencv-python==4.6.0.66 # via # -r examples/libero/requirements.in # robosuite packaging==24.2 # via matplotlib pillow==10.4.0 # via # imageio # 
matplotlib # robosuite # torchvision psutil==6.1.0 # via imageio pygments==2.18.0 # via rich pynput==1.7.7 # via robosuite pyopengl==3.1.7 # via mujoco pyparsing==3.1.4 # via matplotlib python-dateutil==2.9.0.post0 # via matplotlib python-xlib==0.33 # via pynput pyyaml==6.0.2 # via -r examples/libero/requirements.in requests==2.32.3 # via torchvision rich==13.9.4 # via tyro robosuite==1.4.1 # via -r examples/libero/requirements.in scipy==1.10.1 # via robosuite setuptools==75.3.0 # via # imageio-ffmpeg # numba shtab==1.7.1 # via tyro six==1.17.0 # via # pynput # python-dateutil # python-xlib termcolor==2.4.0 # via robosuite torch==1.11.0+cu113 # via # -r examples/libero/requirements.in # torchaudio # torchvision torchaudio==0.11.0+cu113 # via -r examples/libero/requirements.in torchvision==0.12.0+cu113 # via -r examples/libero/requirements.in tqdm==4.67.1 # via -r examples/libero/requirements.in typeguard==4.4.0 # via tyro typing-extensions==4.12.2 # via # etils # rich # torch # torchvision # typeguard # tyro tyro==0.9.2 # via -r examples/libero/requirements.in urllib3==2.2.3 # via requests zipp==3.20.2 # via # etils # importlib-metadata # importlib-resources ================================================ FILE: examples/policy_records.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pathlib\n", "\n", "import numpy as np\n", "\n", "record_path = pathlib.Path(\"../policy_records\")\n", "num_steps = len(list(record_path.glob(\"step_*.npy\")))\n", "\n", "records = []\n", "for i in range(num_steps):\n", " record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n", " records.append(record)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"length of records\", len(records))\n", "print(\"keys in records\", records[0].keys())\n", "\n", "for k in records[0]:\n", " print(f\"{k} shape: 
{records[0][k].shape}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from PIL import Image\n", "\n", "\n", "def get_image(step: int, idx: int = 0):\n", " img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n", " return img[idx].transpose(1, 2, 0)\n", "\n", "\n", "def show_image(step: int, idx_lst: list[int]):\n", " imgs = [get_image(step, idx) for idx in idx_lst]\n", " return Image.fromarray(np.hstack(imgs))\n", "\n", "\n", "for i in range(2):\n", " display(show_image(i, [0]))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "\n", "def get_axis(name, axis):\n", " return np.array([record[name][axis] for record in records])\n", "\n", "\n", "# qpos is [..., 14] of type float:\n", "# 0-5: left arm joint angles\n", "# 6: left arm gripper\n", "# 7-12: right arm joint angles\n", "# 13: right arm gripper\n", "names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n", "\n", "\n", "def make_data():\n", " cur_dim = 0\n", " in_data = {}\n", " out_data = {}\n", " for name, dim_size in names:\n", " for i in range(dim_size):\n", " in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n", " out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n", " cur_dim += 1\n", " return pd.DataFrame(in_data), pd.DataFrame(out_data)\n", "\n", "\n", "in_data, out_data = make_data()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for name in in_data.columns:\n", " data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n", " data.plot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, 
"file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/simple_client/Dockerfile ================================================ # Dockerfile for the simple client. # Build the container: # docker build . -t simple_client -f examples/simple_client/Dockerfile # Run the container: # docker run --rm -it --network=host -v .:/app simple_client /bin/bash FROM python:3.7-slim COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/ WORKDIR /app # Copy from the cache instead of linking since it's a mounted volume ENV UV_LINK_MODE=copy # Write the virtual environment outside of the project directory so it doesn't # leak out of the container when we mount the application code. ENV UV_PROJECT_ENVIRONMENT=/.venv # Copy the requirements files so we can install dependencies. # The rest of the project is mounted as a volume, so we don't need to rebuild on changes. # This strategy is best for development-style usage. COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml # Install python dependencies. RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS" ================================================ FILE: examples/simple_client/README.md ================================================ # Simple Client A minimal client that sends observations to the server and prints the inference rate. You can specify which runtime environment to use using the `--env` flag. 
class EnvMode(enum.Enum):
    """Supported environments."""

    # Each member selects which random-observation generator `main` sends to the server.
    # ALOHA and ALOHA_SIM are mapped to the same generator there.
    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"
class TimingRecorder:
    """Collects timing samples (in milliseconds) under string keys and reports summary statistics."""

    # Quantiles reported by `get_stats`, as (result key, quantile) pairs.
    _QUANTILES = (
        ("p25", 0.25),
        ("p50", 0.50),
        ("p75", 0.75),
        ("p90", 0.90),
        ("p95", 0.95),
        ("p99", 0.99),
    )

    def __init__(self) -> None:
        # Maps a metric name to the samples recorded for it, in arrival order.
        self._timings: dict[str, list[float]] = {}

    def record(self, key: str, time_ms: float) -> None:
        """Record a timing measurement for the given key."""
        self._timings.setdefault(key, []).append(time_ms)

    def get_stats(self, key: str) -> dict[str, float]:
        """Get statistics for the given key.

        Returns:
            A dict with `mean`, `std` and the percentiles `p25`, `p50`, `p75`, `p90`, `p95`, `p99`.

        Raises:
            KeyError: If nothing has been recorded under `key`.
        """
        times = self._timings[key]
        stats = {
            "mean": float(np.mean(times)),
            "std": float(np.std(times)),
        }
        # One quantile call for all percentiles instead of one pass per percentile.
        quantile_values = np.quantile(times, [q for _, q in self._QUANTILES])
        for (name, _), value in zip(self._QUANTILES, quantile_values):
            stats[name] = float(value)
        return stats

    def print_all_stats(self) -> None:
        """Print statistics for all keys in a concise format."""
        # Import the submodules explicitly: a bare `import rich` does not guarantee that
        # `rich.table` / `rich.console` are available as attributes of the package.
        import rich.console
        import rich.table

        table = rich.table.Table(
            title="[bold blue]Timing Statistics[/bold blue]",
            show_header=True,
            header_style="bold white",
            border_style="blue",
            title_justify="center",
        )

        # Add metric column with custom styling
        table.add_column("Metric", style="cyan", justify="left", no_wrap=True)

        # Add statistical columns with consistent styling: (header, style, key into get_stats()).
        stat_columns = [
            ("Mean", "yellow", "mean"),
            ("Std", "yellow", "std"),
            ("P25", "magenta", "p25"),
            ("P50", "magenta", "p50"),
            ("P75", "magenta", "p75"),
            ("P90", "magenta", "p90"),
            ("P95", "magenta", "p95"),
            ("P99", "magenta", "p99"),
        ]

        for name, style, _ in stat_columns:
            table.add_column(name, justify="right", style=style, no_wrap=True)

        # Add rows for each metric with formatted values
        for key in sorted(self._timings.keys()):
            stats = self.get_stats(key)
            values = [f"{stats[stat_key]:.1f}" for _, _, stat_key in stat_columns]
            table.add_row(key, *values)

        # Print with custom console settings
        console = rich.console.Console(width=None, highlight=True)
        console.print(table)

    def write_parquet(self, path: pathlib.Path) -> None:
        """Save the timings to a parquet file."""
        logger.info(f"Writing timings to {path}")
        # NOTE(review): builds one column per key, which presumably requires every key to have
        # the same number of samples -- confirm against the server's timing payload.
        frame = pl.DataFrame(self._timings)
        path.parent.mkdir(parents=True, exist_ok=True)
        frame.write_parquet(path)
def _random_observation_libero() -> dict:
    """Build a random LIBERO-style observation to send to the policy server."""
    image_shape = (224, 224, 3)

    # Draw order (state, image, wrist image) is kept so seeded runs produce the same values.
    state = np.random.rand(8)
    image = np.random.randint(256, size=image_shape, dtype=np.uint8)
    wrist_image = np.random.randint(256, size=image_shape, dtype=np.uint8)

    observation = {
        "observation/state": state,
        "observation/image": image,
        "observation/wrist_image": wrist_image,
        "prompt": "do something",
    }
    return observation
"Finetune on your data" section of the [README](../../README.md) for fine-tuning on UR5 datasets.

First, we define the `UR5Inputs` and `UR5Outputs` classes, which map the UR5 environment to the model and vice versa. See the corresponding classes in `src/openpi/policies/libero_policy.py` for comments explaining each line.

```python

@dataclasses.dataclass(frozen=True)
class UR5Inputs(transforms.DataTransformFn):

    model_type: _model.ModelType = _model.ModelType.PI0

    def __call__(self, data: dict) -> dict:
        # First, concatenate the joints and gripper into the state vector.
        state = np.concatenate([data["joints"], data["gripper"]])

        # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically
        # stores as float32 (C,H,W), gets skipped for policy inference.
        base_image = _parse_image(data["base_rgb"])
        wrist_image = _parse_image(data["wrist_rgb"])

        # Create inputs dict.
        inputs = {
            "state": state,
            "image": {
                "base_0_rgb": base_image,
                "left_wrist_0_rgb": wrist_image,
                # Since there is no right wrist, replace with zeros
                "right_wrist_0_rgb": np.zeros_like(base_image),
            },
            "image_mask": {
                "base_0_rgb": np.True_,
                "left_wrist_0_rgb": np.True_,
                # Since the "slot" for the right wrist is not used, this mask is set
                # to False
                "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
            },
        }

        if "actions" in data:
            inputs["actions"] = data["actions"]

        # Pass the prompt (aka language instruction) to the model.
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


@dataclasses.dataclass(frozen=True)
class UR5Outputs(transforms.DataTransformFn):

    def __call__(self, data: dict) -> dict:
        # Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims
        return {"actions": np.asarray(data["actions"][:, :7])}

```

Next, we define the `LeRobotUR5DataConfig` class, which specifies how to process raw UR5 data from a LeRobot dataset for training.
@dataclasses.dataclass(frozen=True)
class LeRobotUR5DataConfig(DataConfigFactory):
    """Data config describing how raw UR5 data from a LeRobot dataset is processed."""

    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # Boilerplate for remapping keys between the LeRobot dataset and the names that
        # `UR5Inputs` reads (`base_rgb`, `wrist_rgb`, `joints`, `gripper`, `prompt`).
        # Adjust the mapping to match the keys in your own dataset.
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "base_rgb": "image",
                        "wrist_rgb": "wrist_image",
                        "joints": "joints",
                        "gripper": "gripper",
                        "prompt": "prompt",
                    }
                )
            ]
        )

        # These transforms are the ones we wrote earlier.
        # `UR5Inputs` only declares a `model_type` field, so that is the only argument passed.
        data_transforms = _transforms.Group(
            inputs=[UR5Inputs(model_type=model_config.model_type)],
            outputs=[UR5Outputs()],
        )

        # Convert absolute actions to delta actions.
        # By convention, we do not convert the gripper action (7th dimension).
        delta_action_mask = _transforms.make_bool_mask(6, -1)
        data_transforms = data_transforms.push(
            inputs=[_transforms.DeltaActions(delta_action_mask)],
            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
        )

        # Model transforms include things like tokenizing the prompt and action targets
        # You do not need to change anything here for your own dataset.
        model_transforms = ModelTransformFactory()(model_config)

        # We return all data transforms for training and inference. No need to change anything here.
        return dataclasses.replace(
            self.create_base_config(assets_dirs),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
        )
TrainConfig(
    name="pi0_ur5",
    model=pi0.Pi0Config(),
    data=LeRobotUR5DataConfig(
        repo_id="your_username/ur5_dataset",
        # This config lets us reload the UR5 normalization stats from the base model checkpoint.
        # Reloading normalization stats can help transfer pre-trained models to new environments.
        # See docs/norm_stats.md at the repository root ([norm_stats.md](../../docs/norm_stats.md)
        # relative to this README) for more details.
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="ur5e",
        ),
        base_config=DataConfig(
            # This flag determines whether we load the prompt (i.e. the task instruction) from the
            # ``task`` field in the LeRobot dataset. The recommended setting is True.
            prompt_from_task=True,
        ),
    ),
    # Load the pi0 base model checkpoint.
    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
)
def convert_to_uint8(img: np.ndarray) -> np.ndarray:
    """Converts an image to uint8 if it is a float image.

    This is important for reducing the size of the image when sending it over the network.

    Float images are scaled by 255, so they are expected to hold values in [0, 1]. Values
    outside that range are clipped first; without clipping they would wrap around when cast
    to uint8 (for example, a slightly too bright pixel would turn dark).

    Args:
        img: The image to convert. Non-float inputs are returned unchanged.

    Returns:
        A uint8 image for float inputs, otherwise `img` itself.
    """
    if np.issubdtype(img.dtype, np.floating):
        img = (255 * np.clip(img, 0.0, 1.0)).astype(np.uint8)
    return img
def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
    """Replicates tf.image.resize_with_pad for one image using PIL.

    Resizes an image to a target height and width without distortion by padding with zeros.
    The resized image is centered on the zero canvas.

    Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].

    Args:
        image: The image to resize.
        height: The target height.
        width: The target width.
        method: The PIL resampling filter passed to `Image.resize`.

    Returns:
        An image of size (width, height); `image` itself if it already has that size.
    """
    cur_width, cur_height = image.size
    if cur_width == width and cur_height == height:
        return image  # No need to resize if the image is already the correct size.

    # Scale by the more constraining side so that the result fits inside the target.
    ratio = max(cur_width / width, cur_height / height)
    # Clamp to at least one pixel: for extreme aspect ratios the truncation would otherwise
    # yield a zero-sized target, which cannot be resized to.
    resized_height = max(1, int(cur_height / ratio))
    resized_width = max(1, int(cur_width / ratio))
    resized_image = image.resize((resized_width, resized_height), resample=method)

    zero_image = Image.new(resized_image.mode, (width, height), 0)
    pad_height = max(0, int((height - resized_height) / 2))
    pad_width = max(0, int((width - resized_width) / 2))
    zero_image.paste(resized_image, (pad_width, pad_height))
    assert zero_image.size == (width, height)
    return zero_image
def pack_array(obj):
    """msgpack `default` hook that turns NumPy arrays and scalars into plain dicts.

    Arrays become a dict tagged with `__ndarray__` holding the raw bytes, dtype string and shape.
    NumPy scalars become a dict tagged with `__npgeneric__` holding the Python value and dtype string.
    Anything that is not a NumPy array or scalar is returned unchanged.

    Raises:
        ValueError: If the dtype kind is void ("V"), object ("O") or complex ("c").
    """
    is_array = isinstance(obj, np.ndarray)
    is_scalar = isinstance(obj, np.generic)
    if not (is_array or is_scalar):
        return obj

    dtype = obj.dtype
    if dtype.kind in ("V", "O", "c"):
        raise ValueError(f"Unsupported dtype: {dtype}")

    if is_array:
        return {
            b"__ndarray__": True,
            b"data": obj.tobytes(),
            b"dtype": dtype.str,
            b"shape": obj.shape,
        }

    return {
        b"__npgeneric__": True,
        b"data": obj.item(),
        b"dtype": dtype.str,
    }
packages/openpi-client/src/openpi_client/msgpack_numpy_test.py ================================================ import numpy as np import pytest import tree from openpi_client import msgpack_numpy def _check(expected, actual): if isinstance(expected, np.ndarray): assert expected.shape == actual.shape assert expected.dtype == actual.dtype assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f") else: assert expected == actual @pytest.mark.parametrize( "data", [ 1, # int 1.0, # float "hello", # string np.bool_(True), # boolean scalar np.array([1, 2, 3])[0], # int scalar np.str_("asdf"), # string scalar [1, 2, 3], # list {"key": "value"}, # dict {"key": [1, 2, 3]}, # nested dict np.array(1.0), # 0D array np.array([1, 2, 3], dtype=np.int32), # 1D integer array np.array(["asdf", "qwer"]), # string array np.array([True, False]), # boolean array np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16), # 3D integer array np.array([np.nan, np.inf, -np.inf]), # special float values {"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}}, # nested dict with arrays [np.array([1, 2]), np.array([3, 4])], # list of arrays np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros np.ones((2, 3), dtype=np.float64), # 2D ones with double precision ], ) def test_pack_unpack(data): packed = msgpack_numpy.packb(data) unpacked = msgpack_numpy.unpackb(packed) tree.map_structure(_check, data, unpacked) ================================================ FILE: packages/openpi-client/src/openpi_client/runtime/agent.py ================================================ import abc class Agent(abc.ABC): """An Agent is the thing with agency, i.e. the entity that makes decisions. Agents receive observations about the state of the world, and return actions to take in response. 
""" @abc.abstractmethod def get_action(self, observation: dict) -> dict: """Query the agent for the next action.""" @abc.abstractmethod def reset(self) -> None: """Reset the agent to its initial state.""" ================================================ FILE: packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py ================================================ from typing_extensions import override from openpi_client import base_policy as _base_policy from openpi_client.runtime import agent as _agent class PolicyAgent(_agent.Agent): """An agent that uses a policy to determine actions.""" def __init__(self, policy: _base_policy.BasePolicy) -> None: self._policy = policy @override def get_action(self, observation: dict) -> dict: return self._policy.infer(observation) def reset(self) -> None: self._policy.reset() ================================================ FILE: packages/openpi-client/src/openpi_client/runtime/environment.py ================================================ import abc class Environment(abc.ABC): """An Environment represents the robot and the environment it inhabits. The primary contract of environments is that they can be queried for observations about their state, and have actions applied to them to change that state. """ @abc.abstractmethod def reset(self) -> None: """Reset the environment to its initial state. This will be called once before starting each episode. """ @abc.abstractmethod def is_episode_complete(self) -> bool: """Allow the environment to signal that the episode is complete. This will be called after each step. It should return `True` if the episode is complete (either successfully or unsuccessfully), and `False` otherwise. 
""" @abc.abstractmethod def get_observation(self) -> dict: """Query the environment for the current state.""" @abc.abstractmethod def apply_action(self, action: dict) -> None: """Take an action in the environment.""" ================================================ FILE: packages/openpi-client/src/openpi_client/runtime/runtime.py ================================================ import logging import threading import time from openpi_client.runtime import agent as _agent from openpi_client.runtime import environment as _environment from openpi_client.runtime import subscriber as _subscriber class Runtime: """The core module orchestrating interactions between key components of the system.""" def __init__( self, environment: _environment.Environment, agent: _agent.Agent, subscribers: list[_subscriber.Subscriber], max_hz: float = 0, num_episodes: int = 1, max_episode_steps: int = 0, ) -> None: self._environment = environment self._agent = agent self._subscribers = subscribers self._max_hz = max_hz self._num_episodes = num_episodes self._max_episode_steps = max_episode_steps self._in_episode = False self._episode_steps = 0 def run(self) -> None: """Runs the runtime loop continuously until stop() is called or the environment is done.""" for _ in range(self._num_episodes): self._run_episode() # Final reset, this is important for real environments to move the robot to its home position. 
self._environment.reset() def run_in_new_thread(self) -> threading.Thread: """Runs the runtime loop in a new thread.""" thread = threading.Thread(target=self.run) thread.start() return thread def mark_episode_complete(self) -> None: """Marks the end of an episode.""" self._in_episode = False def _run_episode(self) -> None: """Runs a single episode.""" logging.info("Starting episode...") self._environment.reset() self._agent.reset() for subscriber in self._subscribers: subscriber.on_episode_start() self._in_episode = True self._episode_steps = 0 step_time = 1 / self._max_hz if self._max_hz > 0 else 0 last_step_time = time.time() while self._in_episode: self._step() self._episode_steps += 1 # Sleep to maintain the desired frame rate now = time.time() dt = now - last_step_time if dt < step_time: time.sleep(step_time - dt) last_step_time = time.time() else: last_step_time = now logging.info("Episode completed.") for subscriber in self._subscribers: subscriber.on_episode_end() def _step(self) -> None: """A single step of the runtime loop.""" observation = self._environment.get_observation() action = self._agent.get_action(observation) self._environment.apply_action(action) for subscriber in self._subscribers: subscriber.on_step(observation, action) if self._environment.is_episode_complete() or ( self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps ): self.mark_episode_complete() ================================================ FILE: packages/openpi-client/src/openpi_client/runtime/subscriber.py ================================================ import abc class Subscriber(abc.ABC): """Subscribes to events in the runtime. Subscribers can be used to save data, visualize, etc. 
""" @abc.abstractmethod def on_episode_start(self) -> None: """Called when an episode starts.""" @abc.abstractmethod def on_step(self, observation: dict, action: dict) -> None: """Append a step to the episode.""" @abc.abstractmethod def on_episode_end(self) -> None: """Called when an episode ends.""" ================================================ FILE: packages/openpi-client/src/openpi_client/websocket_client_policy.py ================================================ import logging import time from typing import Dict, Optional, Tuple from typing_extensions import override import websockets.sync.client from openpi_client import base_policy as _base_policy from openpi_client import msgpack_numpy class WebsocketClientPolicy(_base_policy.BasePolicy): """Implements the Policy interface by communicating with a server over websocket. See WebsocketPolicyServer for a corresponding server implementation. """ def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None: if host.startswith("ws"): self._uri = host else: self._uri = f"ws://{host}" if port is not None: self._uri += f":{port}" self._packer = msgpack_numpy.Packer() self._api_key = api_key self._ws, self._server_metadata = self._wait_for_server() def get_server_metadata(self) -> Dict: return self._server_metadata def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]: logging.info(f"Waiting for server at {self._uri}...") while True: try: headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None conn = websockets.sync.client.connect( self._uri, compression=None, max_size=None, additional_headers=headers ) metadata = msgpack_numpy.unpackb(conn.recv()) return conn, metadata except ConnectionRefusedError: logging.info("Still waiting for server...") time.sleep(5) @override def infer(self, obs: Dict) -> Dict: # noqa: UP006 data = self._packer.pack(obs) self._ws.send(data) response = self._ws.recv() if 
isinstance(response, str): # we're expecting bytes; if the server sends a string, it's an error. raise RuntimeError(f"Error in inference server:\n{response}") return msgpack_numpy.unpackb(response) @override def reset(self) -> None: pass ================================================ FILE: pyproject.toml ================================================ [project] name = "openpi" version = "0.1.0" description = "Physical Intelligence open source repo" readme = "README.md" requires-python = ">=3.11" license = { file = "LICENSE" } dependencies = [ "augmax>=0.3.4", "dm-tree>=0.1.8", "einops>=0.8.0", "equinox>=0.11.8", "flatbuffers>=24.3.25", "flax==0.10.2", "fsspec[gcs]>=2024.6.0", "gym-aloha>=0.1.1", "imageio>=2.36.1", "jax[cuda12]==0.5.3", "jaxtyping==0.2.36", "lerobot", "ml_collections==1.0.0", "numpy>=1.22.4,<2.0.0", "numpydantic>=1.6.6", "opencv-python>=4.10.0.84", "openpi-client", "orbax-checkpoint==0.11.13", "pillow>=11.0.0", "sentencepiece>=0.2.0", "torch==2.7.1", "tqdm-loggable>=0.2", "typing-extensions>=4.12.2", "tyro>=0.9.5", "wandb>=0.19.1", "filelock>=3.16.1", "beartype==0.19.0", "treescope>=0.1.7", "transformers==4.53.2", "rich>=14.0.0", "polars>=1.30.0", ] [project.urls] Repository = "https://github.com/Physical-Intelligence/openpi" [dependency-groups] dev = [ "pytest>=8.3.4", "ruff>=0.8.6", "pre-commit>=4.0.1", "ipykernel>=6.29.5", "ipywidgets>=8.1.5", "matplotlib>=3.10.0", "pynvml>=12.0.0", ] rlds = [ "dlimp", "tensorflow-cpu==2.15.0", "tensorflow-datasets==4.9.9", ] [tool.uv] override-dependencies = ["ml-dtypes==0.4.1", "tensorstore==0.1.74"] [tool.uv.sources] openpi-client = { workspace = true } lerobot = { git = "https://github.com/huggingface/lerobot", rev = "0cf864870cf29f4738d3ade893e6fd13fbd7cdb5" } dlimp = { git = "https://github.com/kvablack/dlimp", rev = "ad72ce3a9b414db2185bc0b38461d4101a65477a" } [tool.uv.workspace] members = ["packages/*"] [tool.ruff] line-length = 120 target-version = "py311" extend-exclude = ["docker", "third_party", 
"src/openpi/models_pytorch/transformers_replace/*"] [tool.ruff.lint] # https://docs.astral.sh/ruff/rules/ select = [ "B", "C4", "DTZ", "E4", "E7", "E9", "F", "FBT", "FURB", "I", "ICN", "ISC", "LOG", "N", "PD", "PERF", "PIE", "PLC", "PLE", "PLR1", "PLR5", "PLW", "PT", "Q", "RET", "RUF", "SIM", "SLF", "T10", "T20", "UP", "W", ] ignore = [ "F722", # Conflicts with array typing. "T201", # We use print statements. "PD008", # Lots of false positives. "ISC001", # Disabling to support ruff format. "LOG015", # Use logger.info. ] unfixable = [ "B905", # Fix defaults to strict=False, which is not what we want. ] [tool.ruff.lint.isort] force-single-line = true force-sort-within-sections = true single-line-exclusions = ["collections.abc", "typing", "typing_extensions"] known-third-party = ["wandb"] [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [tool.pytest.ini_options] markers = ["manual: should be run manually."] testpaths = ["src", "scripts", "packages"] ================================================ FILE: scripts/__init__.py ================================================ ================================================ FILE: scripts/compute_norm_stats.py ================================================ """Compute normalization statistics for a config. This script is used to compute the normalization statistics for a given config. It will compute the mean and standard deviation of the data in the dataset and save it to the config assets directory. 
""" import numpy as np import tqdm import tyro import openpi.models.model as _model import openpi.shared.normalize as normalize import openpi.training.config as _config import openpi.training.data_loader as _data_loader import openpi.transforms as transforms class RemoveStrings(transforms.DataTransformFn): def __call__(self, x: dict) -> dict: return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)} def create_torch_dataloader( data_config: _config.DataConfig, action_horizon: int, batch_size: int, model_config: _model.BaseModelConfig, num_workers: int, max_frames: int | None = None, ) -> tuple[_data_loader.Dataset, int]: if data_config.repo_id is None: raise ValueError("Data config must have a repo_id") dataset = _data_loader.create_torch_dataset(data_config, action_horizon, model_config) dataset = _data_loader.TransformedDataset( dataset, [ *data_config.repack_transforms.inputs, *data_config.data_transforms.inputs, # Remove strings since they are not supported by JAX and are not needed to compute norm stats. RemoveStrings(), ], ) if max_frames is not None and max_frames < len(dataset): num_batches = max_frames // batch_size shuffle = True else: num_batches = len(dataset) // batch_size shuffle = False data_loader = _data_loader.TorchDataLoader( dataset, local_batch_size=batch_size, num_workers=num_workers, shuffle=shuffle, num_batches=num_batches, ) return data_loader, num_batches def create_rlds_dataloader( data_config: _config.DataConfig, action_horizon: int, batch_size: int, max_frames: int | None = None, ) -> tuple[_data_loader.Dataset, int]: dataset = _data_loader.create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=False) dataset = _data_loader.IterableTransformedDataset( dataset, [ *data_config.repack_transforms.inputs, *data_config.data_transforms.inputs, # Remove strings since they are not supported by JAX and are not needed to compute norm stats. 
def main(config_name: str, max_frames: int | None = None):
    # Accumulates running statistics over the "state" and "actions" entries of every batch produced
    # for `config_name`, then saves them under `config.assets_dirs / data_config.repo_id`.
    # NOTE(review): documented with comments rather than a docstring so that whatever help text
    # tyro.cli derives from this function stays as it was.
    config = _config.get_config(config_name)
    data_config = config.data.create(config.assets_dirs, config.model)
    horizon = config.model.action_horizon

    # Configs with an RLDS data dir use the RLDS loader; all others use the torch data loader.
    if data_config.rlds_data_dir is None:
        data_loader, num_batches = create_torch_dataloader(
            data_config, horizon, config.batch_size, config.model, config.num_workers, max_frames
        )
    else:
        data_loader, num_batches = create_rlds_dataloader(data_config, horizon, config.batch_size, max_frames)

    running = {
        "state": normalize.RunningStats(),
        "actions": normalize.RunningStats(),
    }
    for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
        for name, accumulator in running.items():
            accumulator.update(np.asarray(batch[name]))

    norm_stats = {name: accumulator.get_statistics() for name, accumulator in running.items()}

    output_path = config.assets_dirs / data_config.repo_id
    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats)
#!/bin/bash
# Installs Docker Engine on Ubuntu 22.04 from Docker's official apt repository, adds the current
# user to the 'docker' group and enables the Docker services at boot.

# Add Docker's official GPG key:
sudo apt-get update
sudo apt-get install -y ca-certificates curl
sudo install -m 0755 -d /etc/apt/keyrings
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
sudo chmod a+r /etc/apt/keyrings/docker.asc

# Add the repository to Apt sources:
echo \
    "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
  $(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
    sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
sudo apt-get update

sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin

# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
# See https://docs.docker.com/engine/install/linux-postinstall/
username=$(whoami)
sudo usermod -aG docker "$username"

# Configure docker to start automatically on system boot.
sudo systemctl enable docker.service
sudo systemctl enable containerd.service

# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
# Only patch the config when it actually exists. A bare `[ path ]` merely tests for a non-empty
# string and is always true, which made sed fail on machines without a Docker config file.
if [ -f ~/.docker/config.json ]; then
    sed -i 's/credsStore/credStore/g' ~/.docker/config.json
fi

echo ""
echo "********************************************************************"
echo "**** Restart to allow Docker permission changes to take effect. ****"
echo "********************************************************************"
echo ""
RUN apt-get update && apt-get install -y git git-lfs linux-headers-generic build-essential clang # Copy from the cache instead of linking since it's a mounted volume ENV UV_LINK_MODE=copy # Write the virtual environment outside of the project directory so it doesn't # leak out of the container when we mount the application code. ENV UV_PROJECT_ENVIRONMENT=/.venv # Install the project's dependencies using the lockfile and settings RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,source=uv.lock,target=uv.lock \ --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \ --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \ GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev # Copy transformers_replace files while preserving directory structure COPY src/openpi/models_pytorch/transformers_replace/ /tmp/transformers_replace/ RUN /.venv/bin/python -c "import transformers; print(transformers.__file__)" | xargs dirname | xargs -I{} cp -r /tmp/transformers_replace/* {} && rm -rf /tmp/transformers_replace CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS" ================================================ FILE: scripts/serve_policy.py ================================================ import dataclasses import enum import logging import socket import tyro from openpi.policies import policy as _policy from openpi.policies import policy_config as _policy_config from openpi.serving import websocket_policy_server from openpi.training import config as _config class EnvMode(enum.Enum): """Supported environments.""" ALOHA = "aloha" ALOHA_SIM = "aloha_sim" DROID = "droid" LIBERO = "libero" @dataclasses.dataclass class Checkpoint: """Load a policy from a trained checkpoint.""" # Training config name (e.g., "pi0_aloha_sim"). 
def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
    """Build the policy registered in `DEFAULT_CHECKPOINT` for `env`.

    Args:
        env: Environment whose default checkpoint should be loaded.
        default_prompt: Forwarded to `create_trained_policy`.

    Raises:
        ValueError: If `env` has no entry in `DEFAULT_CHECKPOINT`.
    """
    checkpoint = DEFAULT_CHECKPOINT.get(env)
    if not checkpoint:
        raise ValueError(f"Unsupported environment mode: {env}")

    train_config = _config.get_config(checkpoint.config)
    return _policy_config.create_trained_policy(train_config, checkpoint.dir, default_prompt=default_prompt)
def init_logging():
    """Custom logging format for better readability.

    Sets the root logger to INFO and installs a formatter on its first handler that abbreviates level
    names to a single letter (e.g. "WARNING" -> "W"). If the root logger has no handler yet, a
    `StreamHandler` is added first so that there is something to configure.
    """
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            # Unknown level names are left untouched.
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # The root logger only has a handler if something configured logging earlier; without this guard
    # `handlers[0]` raises IndexError in a fresh process.
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    logger.handlers[0].setFormatter(formatter)
@at.typecheck
def train_step(
    config: _config.TrainConfig,
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    batch: tuple[_model.Observation, _model.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    """Runs one optimization step and returns the updated train state plus logging metrics.

    Args:
        config: Training config; its `trainable_filter` selects the parameters that are differentiated
            and updated.
        rng: Base random key. The current `state.step` is folded in to derive the key used for the loss.
        state: Current train state (step, params, model definition, optimizer and EMA state).
        batch: An `(observation, actions)` pair.

    Returns:
        A tuple `(new_state, info)` where `new_state` has `step + 1`, the updated params and optimizer
        state (and updated `ema_params` when `ema_decay` is set), and `info` holds "loss", "grad_norm"
        and "param_norm".
    """
    # Rebuild the model object from its definition and current parameters, in train mode.
    model = nnx.merge(state.model_def, state.params)
    model.train()

    @at.typecheck
    def loss_fn(
        model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
    ):
        # The scalar training loss is the mean over whatever `compute_loss` returns.
        chunked_loss = model.compute_loss(rng, observation, actions, train=True)
        return jnp.mean(chunked_loss)

    # Derive a per-step key so that every step uses different randomness.
    train_rng = jax.random.fold_in(rng, state.step)
    observation, actions = batch

    # Filter out frozen params.
    diff_state = nnx.DiffState(0, config.trainable_filter)
    loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)

    # Apply the optimizer update to the trainable subset only.
    params = state.params.filter(config.trainable_filter)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Update the model in place and return the new full state.
    nnx.update(model, new_params)
    new_params = nnx.state(model)

    new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
    if state.ema_decay is not None:
        # Exponential moving average: ema <- decay * ema + (1 - decay) * new_params.
        new_state = dataclasses.replace(
            new_state,
            ema_params=jax.tree.map(
                lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
            ),
        )

    # Filter out params that aren't kernels.
    # Kept: `nnx.Param` leaves with more than one dimension whose path does not match the regex below.
    kernel_params = nnx.state(
        model,
        nnx.All(
            nnx.Param,
            nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
            lambda _, x: x.value.ndim > 1,
        ),
    )
    info = {
        "loss": loss,
        "grad_norm": optax.global_norm(grads),
        "param_norm": optax.global_norm(kernel_params),
    }
    return new_state, info
) jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser())) rng = jax.random.key(config.seed) train_rng, init_rng = jax.random.split(rng) mesh = sharding.make_mesh(config.fsdp_devices) data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS)) replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir( config.checkpoint_dir, keep_period=config.keep_period, overwrite=config.overwrite, resume=config.resume, ) init_wandb(config, resuming=resuming, enabled=config.wandb_enabled) data_loader = _data_loader.create_data_loader( config, sharding=data_sharding, shuffle=True, ) data_iter = iter(data_loader) batch = next(data_iter) logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}") # Log images from first batch to sanity check. images_to_log = [ wandb.Image(np.concatenate([np.array(img[i]) for img in batch[0].images.values()], axis=1)) for i in range(min(5, len(next(iter(batch[0].images.values()))))) ] wandb.log({"camera_views": images_to_log}, step=0) train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming) jax.block_until_ready(train_state) logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}") if resuming: train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader) ptrain_step = jax.jit( functools.partial(train_step, config), in_shardings=(replicated_sharding, train_state_sharding, data_sharding), out_shardings=(train_state_sharding, replicated_sharding), donate_argnums=(1,), ) start_step = int(train_state.step) pbar = tqdm.tqdm( range(start_step, config.num_train_steps), initial=start_step, total=config.num_train_steps, dynamic_ncols=True, ) infos = [] for step in pbar: with sharding.set_mesh(mesh): train_state, info = ptrain_step(train_rng, train_state, batch) 
infos.append(info) if step % config.log_interval == 0: stacked_infos = common_utils.stack_forest(infos) reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos)) info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items()) pbar.write(f"Step {step}: {info_str}") wandb.log(reduced_info, step=step) infos = [] batch = next(data_iter) if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1: _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step) logging.info("Waiting for checkpoint manager to finish") checkpoint_manager.wait_until_finished() if __name__ == "__main__": main(_config.cli()) ================================================ FILE: scripts/train_pytorch.py ================================================ """ PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support. This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs entirely in PyTorch using the `PI0Pytorch` model and your existing config/data pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`. 
def init_logging():
    """Configure the root logger with a compact, single-letter level format.

    The root logger level is set to INFO. If the root logger has no handlers a
    ``StreamHandler`` is added; otherwise the formatter of the first existing
    handler is replaced. Level names are abbreviated (``INFO`` -> ``I`` ...) in
    the formatted output only.
    """
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            # A LogRecord is shared by every handler attached to the logger, so
            # the abbreviated name is applied only while formatting and then
            # restored; otherwise other handlers and filters would observe
            # "I"/"W" instead of "INFO"/"WARNING".
            original_levelname = record.levelname
            record.levelname = level_mapping.get(original_levelname, original_levelname)
            try:
                return super().format(record)
            finally:
                record.levelname = original_levelname

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        logger.addHandler(ch)
    else:
        logger.handlers[0].setFormatter(formatter)
def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
    """Start (or resume) the wandb run for this experiment.

    With ``enabled=False`` wandb is initialised in disabled mode. Otherwise the
    run id is stored in ``<checkpoint_dir>/wandb_id.txt`` for new runs and read
    back from that file when resuming.

    Raises:
        FileNotFoundError: If the checkpoint directory does not exist.
    """
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")

    run_id_file = ckpt_dir / "wandb_id.txt"
    if not resuming:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        run_id_file.write_text(wandb.run.id)
        return

    # Resuming: re-attach to the run recorded by the original launch.
    wandb.init(id=run_id_file.read_text().strip(), resume="must", project=config.project_name)


def setup_ddp():
    """Initialise torch.distributed when WORLD_SIZE > 1 and pick this rank's device.

    Returns:
        ``(use_ddp, local_rank, device)``.
    """
    use_ddp = int(os.environ.get("WORLD_SIZE", "1")) > 1
    if use_ddp and not torch.distributed.is_initialized():
        chosen_backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=chosen_backend, init_method="env://")

        # Set up debugging environment variables for DDP issues
        # (a value already present in the environment is left alone).
        os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "INFO")

    # LOCAL_RANK wins; RANK is the fallback, then 0.
    fallback_rank = os.environ.get("RANK", "0")
    local_rank = int(os.environ.get("LOCAL_RANK", fallback_rank))

    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    return use_ddp, local_rank, device


def cleanup_ddp():
    """Synchronise all ranks and tear down the process group, if one exists."""
    if not torch.distributed.is_initialized():
        return
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()


def set_seed(seed: int, local_rank: int):
    """Seed torch, numpy and (when available) CUDA with ``seed + local_rank``."""
    rank_seed = seed + local_rank
    torch.manual_seed(rank_seed)
    np.random.seed(rank_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(rank_seed)


def build_datasets(config: _config.TrainConfig):
    """Create the shuffled PyTorch data loader and return it with its data config."""
    # Use the unified data loader with PyTorch framework
    loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
    return loader, loader.data_config()


def get_model_state_dict(model):
    """Get state dict from model, handling DDP wrapper."""
    is_wrapped = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    inner_model = model.module if is_wrapped else model
    return inner_model.state_dict()
def get_model_parameters(model):
    """Get parameters from model, handling DDP wrapper."""
    return (
        model.module.parameters()
        if isinstance(model, torch.nn.parallel.DistributedDataParallel)
        else model.parameters()
    )


def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
    """Save a checkpoint with model state, optimizer state, and metadata.

    Only the main process writes. A checkpoint is written when ``global_step``
    is a positive multiple of ``config.save_interval`` or when training is at
    its end. The checkpoint is assembled in ``tmp_<step>`` and renamed to
    ``<step>`` once complete.

    Args:
        model: The model, optionally wrapped in DistributedDataParallel.
        optimizer: Optimizer whose ``state_dict()`` is saved to ``optimizer.pt``.
        global_step: Number of optimizer steps completed so far. The training
            loop calls this function after incrementing its step counter.
        config: Training config (``save_interval``, ``num_train_steps``,
            ``checkpoint_dir``, ``wandb_enabled`` are read).
        is_main: Whether this process is the main rank.
        data_config: Data config providing ``norm_stats`` and ``asset_id``.
    """
    if not is_main:
        return

    on_interval = global_step % config.save_interval == 0 and global_step > 0
    # The caller passes the already-incremented step, so the last update of a
    # run arrives as global_step == num_train_steps. The historical
    # ``num_train_steps - 1`` checkpoint is still written so that no checkpoint
    # produced by earlier versions disappears; ``>=`` additionally guarantees
    # that the weights after the final update are saved.
    at_end_of_training = global_step >= config.num_train_steps - 1
    if not (on_interval or at_end_of_training):
        return

    # Create temporary directory for atomic checkpoint saving
    final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
    tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"

    # Remove any existing temp directory and create new one
    if tmp_ckpt_dir.exists():
        shutil.rmtree(tmp_ckpt_dir)
    tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Save model state using safetensors (handle shared tensors)
    model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors")

    # Save optimizer state using PyTorch format
    torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")

    # Save training metadata. The config is stored as a plain dict
    # (dataclasses.asdict) rather than as the config object itself.
    metadata = {
        "global_step": global_step,
        "config": dataclasses.asdict(config),
        "timestamp": time.time(),
    }
    torch.save(metadata, tmp_ckpt_dir / "metadata.pt")

    # save norm stats
    norm_stats = data_config.norm_stats
    if norm_stats is not None and data_config.asset_id is not None:
        _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats)

    # Move temp directory to final location, replacing any previous checkpoint
    # for the same step.
    if final_ckpt_dir.exists():
        shutil.rmtree(final_ckpt_dir)
    tmp_ckpt_dir.rename(final_ckpt_dir)

    logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")

    # Log checkpoint to wandb
    if config.wandb_enabled:
        wandb.log({"checkpoint_step": global_step}, step=global_step)
def load_checkpoint(model, optimizer, checkpoint_dir, device):
    """Load the latest checkpoint and return the global step.

    The newest numerically named sub-directory of ``checkpoint_dir`` is used.
    Model weights are read from ``model.safetensors``, optimizer state from
    ``optimizer.pt`` and the step from ``metadata.pt`` (falling back to the
    directory name when the metadata has no ``global_step``).

    Args:
        model: Model to load into, optionally wrapped in DistributedDataParallel.
        optimizer: Optimizer to restore.
        checkpoint_dir: Directory containing one sub-directory per saved step.
        device: Device the tensors are mapped to.

    Raises:
        FileNotFoundError: If no checkpoint directory, model file or optimizer
            file is found.
        RuntimeError: Re-raised with an allocator hint when loading runs out of
            memory; other RuntimeErrors propagate unchanged.
    """
    # Single source of truth for "which checkpoint is latest".
    latest_step = get_latest_checkpoint_step(checkpoint_dir)
    if latest_step is None:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    ckpt_dir = checkpoint_dir / f"{latest_step}"

    # Clear memory before loading checkpoints
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "before_loading_checkpoint")

    try:
        # Load model state with error handling
        logging.info("Loading model state...")
        safetensors_path = ckpt_dir / "model.safetensors"
        if not safetensors_path.exists():
            raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")
        model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
        safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
        logging.info("Loaded model state from safetensors format")

        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_model")

        # Load optimizer state with error handling
        logging.info("Loading optimizer state...")
        optimizer_path = ckpt_dir / "optimizer.pt"
        if not optimizer_path.exists():
            raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from a trusted location.
        optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
        logging.info("Loaded optimizer state from pt format")

        optimizer.load_state_dict(optimizer_state_dict)
        # Drop the loaded copy before the next allocation to keep the peak low.
        del optimizer_state_dict
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_optimizer")

        # Load metadata
        logging.info("Loading metadata...")
        metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
        global_step = metadata.get("global_step", latest_step)
        del metadata
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_metadata")

        logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
        return global_step

    except RuntimeError as e:
        if "out of memory" in str(e):
            # Clear memory and provide detailed error message
            torch.cuda.empty_cache()
            gc.collect()
            logging.error(f"Out of memory error while loading checkpoint: {e!s}")
            log_memory_usage(device, latest_step, "after_oom_error")
            raise RuntimeError(
                "Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
            ) from e
        raise


def get_latest_checkpoint_step(checkpoint_dir):
    """Get the latest checkpoint step number from a checkpoint directory.

    Only sub-directories whose name consists solely of decimal digits count as
    checkpoints, which also excludes in-progress ``tmp_<step>`` directories.

    Returns:
        The largest step found, or ``None`` when there is no checkpoint.
    """
    # isdecimal() rather than isdigit(): isdigit() also accepts characters such
    # as "²" that int() cannot parse, which would abort the whole scan.
    checkpoint_steps = [
        int(entry.name) for entry in checkpoint_dir.iterdir() if entry.is_dir() and entry.name.isdecimal()
    ]
    return max(checkpoint_steps, default=None)


def log_memory_usage(device, step, phase="unknown"):
    """Log detailed memory usage information.

    Does nothing when CUDA is unavailable. All sizes are reported in units of
    1e9 bytes; "free" is reserved minus allocated memory.
    """
    if not torch.cuda.is_available():
        return

    allocated_bytes = torch.cuda.memory_allocated(device)
    reserved_bytes = torch.cuda.memory_reserved(device)
    memory_allocated = allocated_bytes / 1e9
    memory_reserved = reserved_bytes / 1e9
    memory_free = (reserved_bytes - allocated_bytes) / 1e9

    # Get more detailed memory info
    memory_stats = torch.cuda.memory_stats(device)
    max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9
    max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9

    # Get DDP info if available
    ddp_info = ""
    if dist.is_initialized():
        ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"

    logging.info(
        f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}"
    )
def train_loop(config: _config.TrainConfig):
    """Train the PI0 PyTorch model described by ``config``, optionally under DDP.

    Steps, in order: set up the process group and seeds, prepare the checkpoint
    directory (resume / overwrite / create), start wandb on the main process,
    build the data loader, log a few sample images, build the model and wrap it
    in DistributedDataParallel when WORLD_SIZE > 1, optionally load initial
    weights from ``config.pytorch_weight_path``, create AdamW, optionally
    restore the latest checkpoint, then iterate over the loader until
    ``config.num_train_steps`` optimizer steps have been taken.

    Raises:
        FileNotFoundError: If ``config.resume`` is set but the checkpoint
            directory is missing or contains no checkpoint.
    """
    use_ddp, local_rank, device = setup_ddp()
    is_main = (not use_ddp) or (dist.get_rank() == 0)
    set_seed(config.seed, local_rank)

    # Initialize checkpoint directory and wandb
    resuming = False
    if config.resume:
        # Find checkpoint directory based on experiment name
        exp_checkpoint_dir = config.checkpoint_dir
        if not exp_checkpoint_dir.exists():
            raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
        latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
        if latest_step is None:
            raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
        resuming = True
        logging.info(f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}")
    elif config.overwrite and config.checkpoint_dir.exists():
        shutil.rmtree(config.checkpoint_dir)
        logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")

    # Create checkpoint directory with experiment name
    if not resuming:
        # For new runs, create experiment-specific checkpoint directory
        exp_checkpoint_dir = config.checkpoint_dir
        exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
    else:
        # For resume, checkpoint_dir is already set to the experiment directory
        logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")

    # Initialize wandb (only on main process)
    if is_main:
        init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    # Build data loader using the unified data loader
    # Calculate effective batch size per GPU for DDP
    # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size
    world_size = torch.distributed.get_world_size() if use_ddp else 1
    effective_batch_size = config.batch_size // world_size
    logging.info(
        f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
    )

    # Pass the original batch size to data loader - it will handle DDP splitting internally
    loader, data_config = build_datasets(config)

    # Log sample images to wandb on first batch
    if is_main and config.wandb_enabled and not resuming:
        # Create a separate data loader for sample batch to avoid consuming the main loader
        sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
        observation, actions = next(iter(sample_data_loader))
        camera_images = observation.to_dict()["image"]

        # Get batch size from the first image tensor
        num_samples = min(5, next(iter(camera_images.values())).shape[0])
        # For each sample, move channels last (CHW -> HWC) and concatenate all
        # camera views horizontally. Built as one expression so no loop
        # variable is left undefined when the batch is empty.
        images_to_log = [
            wandb.Image(
                torch.cat([img[i].permute(1, 2, 0) for img in camera_images.values()], axis=1).cpu().numpy()
            )
            for i in range(num_samples)
        ]
        wandb.log({"camera_views": images_to_log}, step=0)

        # Clear sample batch and its loader from memory before building the model
        del observation, actions, camera_images, images_to_log
        del sample_data_loader
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logging.info("Cleared sample batch and data loader from memory")

    # Build model
    if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        # Convert dataclass to Pi0Config if needed
        model_cfg = openpi.models.pi0_config.Pi0Config(
            dtype=config.pytorch_training_precision,
            action_dim=config.model.action_dim,
            action_horizon=config.model.action_horizon,
            max_token_len=config.model.max_token_len,
            paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
            action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
            pi05=getattr(config.model, "pi05", False),
        )
    else:
        model_cfg = config.model
        # Update dtype to match pytorch_training_precision
        # (object.__setattr__ bypasses the dataclass's own attribute setter).
        object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)

    model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)

    if hasattr(model, "gradient_checkpointing_enable"):
        enable_gradient_checkpointing = True
        model.gradient_checkpointing_enable()
        logging.info("Enabled gradient checkpointing for memory optimization")
    else:
        enable_gradient_checkpointing = False
        logging.info("Gradient checkpointing is not supported for this model")

    # Log initial memory usage after model creation
    if is_main and torch.cuda.is_available():
        log_memory_usage(device, 0, "after_model_creation")

    # Enable memory optimizations for large-scale training
    if world_size >= 8:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Set memory allocation configuration
        # NOTE(review): the model has already been moved to the device at this
        # point; confirm the allocator still honours this variable when it is
        # set this late.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
        logging.info("Enabled memory optimizations for 8+ GPU training")

    if use_ddp:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[device.index] if device.type == "cuda" else None,
            find_unused_parameters=True,  # tolerate parameters that get no gradient in a step
            gradient_as_bucket_view=True,  # Enable for memory efficiency
            static_graph=world_size >= 8,  # Enable for 8+ GPUs
        )

    # Load initial weights from pytorch_weight_path if specified (for fine-tuning)
    if config.pytorch_weight_path is not None:
        logging.info(f"Loading weights from: {config.pytorch_weight_path}")

        model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
        safetensors.torch.load_model(
            (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path
        )
        logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")

    # Optimizer + learning rate schedule from config
    warmup_steps = config.lr_schedule.warmup_steps
    peak_lr = config.lr_schedule.peak_lr
    decay_steps = config.lr_schedule.decay_steps
    end_lr = config.lr_schedule.decay_lr

    # Create optimizer with config parameters
    optim = torch.optim.AdamW(
        model.parameters(),
        lr=peak_lr,
        betas=(config.optimizer.b1, config.optimizer.b2),
        eps=config.optimizer.eps,
        weight_decay=config.optimizer.weight_decay,
    )

    # Load checkpoint if resuming
    global_step = 0
    if resuming:
        global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
        logging.info(f"Resumed training from step {global_step}")

    def lr_schedule(step: int):
        """Linear warmup to ``peak_lr`` followed by cosine decay to ``end_lr``."""
        if step < warmup_steps:
            # Match JAX behavior: start from peak_lr / (warmup_steps + 1)
            init_lr = peak_lr / (warmup_steps + 1)
            return init_lr + (peak_lr - init_lr) * step / warmup_steps
        # cosine decay
        progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
        cos = 0.5 * (1 + np.cos(np.pi * progress))
        return end_lr + (peak_lr - end_lr) * cos

    model.train()
    start_time = time.time()
    infos = []  # Collect stats over log interval
    if is_main:
        logging.info(
            f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
        )
        logging.info(
            f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
        )
        logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
        logging.info(
            f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
        )
        logging.info(
            f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
        )
        logging.info("EMA is not supported for PyTorch training")
        logging.info(f"Training precision: {model_cfg.dtype}")

    # Training loop - iterate until we reach num_train_steps
    pbar = (
        tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
        if is_main
        else None
    )

    while global_step < config.num_train_steps:
        # Set epoch for distributed training
        if use_ddp and hasattr(loader, "set_epoch"):
            loader.set_epoch(global_step // len(loader))

        for observation, actions in loader:
            # Check if we've reached the target number of steps
            if global_step >= config.num_train_steps:
                break

            # The unified data loader returns (observation, actions) tuple
            observation = jax.tree.map(lambda x: x.to(device), observation)  # noqa: PLW2901
            actions = actions.to(torch.float32)  # noqa: PLW2901
            actions = actions.to(device)  # noqa: PLW2901

            # Update LR
            for pg in optim.param_groups:
                pg["lr"] = lr_schedule(global_step)

            # Forward pass
            losses = model(observation, actions)
            # Ensure losses is a tensor and handle different return types
            if isinstance(losses, list | tuple):
                losses = torch.stack(losses)
            elif not isinstance(losses, torch.Tensor):
                losses = torch.tensor(losses, device=device, dtype=torch.float32)

            loss = losses.mean()

            # Backward pass
            loss.backward()

            # Log memory usage after backward pass
            if global_step < 5 and is_main and torch.cuda.is_available():
                log_memory_usage(device, global_step, "after_backward")

            # Gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)

            # Optimizer step. set_to_none=True already releases every gradient
            # tensor, so no further per-parameter clearing is needed.
            optim.step()
            optim.zero_grad(set_to_none=True)

            # Read the scalar loss once per step; it is used for both the stats
            # and the progress bar, which exist only on the main process.
            loss_value = loss.item() if is_main else None

            # Collect stats
            if is_main:
                infos.append(
                    {
                        "loss": loss_value,
                        "learning_rate": optim.param_groups[0]["lr"],
                        "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
                    }
                )

            if is_main and (global_step % config.log_interval == 0):
                elapsed = time.time() - start_time

                # Average stats over log interval
                num_logged = len(infos)
                avg_loss = sum(info["loss"] for info in infos) / num_logged
                avg_lr = sum(info["learning_rate"] for info in infos) / num_logged

                grad_norms = [info["grad_norm"] for info in infos if info["grad_norm"] is not None]
                avg_grad_norm = sum(grad_norms) / len(grad_norms) if grad_norms else None

                logging.info(
                    f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s"
                    if avg_grad_norm is not None
                    else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s"
                )

                # Log to wandb
                if config.wandb_enabled:
                    log_payload = {
                        "loss": avg_loss,
                        "learning_rate": avg_lr,
                        "step": global_step,
                        # Divide by the number of steps actually accumulated:
                        # the first interval (and the first one after a resume)
                        # can be shorter than log_interval.
                        "time_per_step": elapsed / num_logged,
                    }
                    if avg_grad_norm is not None:
                        log_payload["grad_norm"] = avg_grad_norm
                    wandb.log(log_payload, step=global_step)

                start_time = time.time()
                infos = []  # Reset stats collection

            global_step += 1
            # Save checkpoint using the new mechanism
            save_checkpoint(model, optim, global_step, config, is_main, data_config)

            # Update progress bar
            if pbar is not None:
                pbar.update(1)
                pbar.set_postfix(
                    {"loss": f"{loss_value:.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step}
                )

    # Close progress bar
    if pbar is not None:
        pbar.close()

    # Finish wandb run
    if is_main and config.wandb_enabled:
        wandb.finish()

    cleanup_ddp()


def main():
    """Command-line entry point: configure logging, parse the config and train."""
    init_logging()
    config = _config.cli()
    train_loop(config)


if __name__ == "__main__":
    main()
@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
    """Smoke-test a short fresh run, then resume it for two more steps."""
    base_config = _config._CONFIGS_DICT[config_name]  # noqa: SLF001
    fresh_run = dataclasses.replace(
        base_config,
        batch_size=2,
        checkpoint_base_dir=str(tmp_path / "checkpoint"),
        exp_name="test",
        overwrite=False,
        resume=False,
        num_train_steps=2,
        log_interval=1,
    )
    train.main(fresh_run)

    # test resuming
    resumed_run = dataclasses.replace(fresh_run, resume=True, num_train_steps=4)
    train.main(resumed_run)
We follow this einsum axis naming convention: B: batch T: query length S: k/v length N: num query heads K: num k/v heads G: num query heads per k/v head H: head dim D: d_model ("features") """ from collections.abc import Sequence import dataclasses from typing import Literal, TypeAlias import einops import flax.linen as nn import jax import jax.numpy as jnp import openpi.models.lora as lora import openpi.shared.array_typing as at import openpi.training.sharding as sharding PALIGEMMA_VOCAB_SIZE = 257_152 @dataclasses.dataclass class Config: width: int depth: int mlp_dim: int num_heads: int num_kv_heads: int head_dim: int lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict) Variant = Literal["dummy", "gemma_300m", "gemma_300m_lora", "gemma_2b", "gemma_2b_lora"] def get_config(variant: Variant) -> Config: """Returns config for specified gemma variant.""" if variant == "dummy": return Config( width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16, ) if variant == "gemma_300m": # 311M params return Config( width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256, ) if variant == "gemma_2b": return Config( width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256, ) if variant == "gemma_2b_lora": return Config( width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256, lora_configs={"attn": lora.LoRAConfig(rank=16, alpha=16.0), "ffn": lora.LoRAConfig(rank=16, alpha=16.0)}, ) if variant == "gemma_300m_lora": # 311M params return Config( width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256, lora_configs={"attn": lora.LoRAConfig(rank=32, alpha=32.0), "ffn": lora.LoRAConfig(rank=32, alpha=32.0)}, ) raise ValueError(f"Unknown variant: {variant}") @at.typecheck class RMSNorm(nn.Module): @nn.compact def __call__(self, x, cond): dtype = x.dtype # original dtype, could be half-precision var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, 
keepdims=True) # compute variance in float32 normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32 if cond is None: # regular RMSNorm scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1])) normed_inputs = normed_inputs * ( 1 + scale ) # scale by learned parameter in float32 (matches Flax implementation) return normed_inputs.astype(dtype), None # return in original dtype # adaptive RMSNorm modulation = nn.Dense(x.shape[-1] * 3, kernel_init=nn.initializers.zeros, dtype=dtype)(cond) scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1) normed_inputs = normed_inputs * (1 + scale) + shift # scale and shift in float32 return normed_inputs.astype(dtype), gate @at.typecheck class Embedder(nn.Module): """Embedder module.""" vocab_size: int embed_dim: int def setup(self): self.input_embedding_table = self.param( "input_embedding", nn.initializers.normal(), (self.vocab_size, self.embed_dim), ) def encode(self, x): x = self.input_embedding_table[(x,)] x *= jnp.sqrt(self.embed_dim).astype(x.dtype) return x def decode(self, x): return jnp.dot(x, self.input_embedding_table.T) @at.typecheck class Attention(nn.Module): """Attention module.""" configs: Sequence[Config] @nn.compact def __call__(self, xs, positions, attn_mask, kv_cache): # all experts must share the same head dim, num heads, and num kv heads for self-attention to work assert all(config.head_dim == self.configs[0].head_dim for config in self.configs) assert all(config.num_heads == self.configs[0].num_heads for config in self.configs) assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs) dtype = next(x.dtype for x in xs if x is not None) # original dtype, could be half-precision qkvs = [] for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)): if x is None: continue if config.num_kv_heads == config.num_heads: qkv_einsum = lora.Einsum( shape=(3, config.num_heads, config.width, 
config.head_dim), name=_name("qkv_einsum", i), init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)), lora_config=config.lora_configs.get("attn"), ) qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x)) else: q_einsum = lora.Einsum( shape=(config.num_heads, config.width, config.head_dim), name=_name("q_einsum", i), init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), lora_config=config.lora_configs.get("attn"), ) q = q_einsum("BTD,NDH->BTNH", x) kv_einsum = lora.Einsum( shape=(2, config.num_kv_heads, config.width, config.head_dim), name=_name("kv_einsum", i), init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)), lora_config=config.lora_configs.get("attn"), ) k, v = kv_einsum("BSD,2KDH->2BSKH", x) qkvs.append((q, k, v)) q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True)) q = _apply_rope(q, positions=positions) q *= self.configs[0].head_dim ** -0.5 k = _apply_rope(k, positions=positions) # should still be half-precision here (if input was half-precision) assert q.dtype == k.dtype == v.dtype == dtype if kv_cache is not None: cache_k, cache_v = kv_cache k = jnp.concatenate([cache_k, k], axis=1) v = jnp.concatenate([cache_v, v], axis=1) q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads) logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32) if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]): raise ValueError( f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}" ) # big_neg = jnp.finfo(logits.dtype).min big_neg = -2.3819763e38 # See gemma/modules.py masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg) probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype) encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v) encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H") out = [] start = 0 for i, (x, config) in enumerate(zip(xs, 
@at.typecheck
class Block(nn.Module):
    """Transformer block.

    Holds one config per expert. Each expert gets its own pre-attention norm, pre-FFW norm and MLP
    (named through `_name`), while a single `Attention` module named "attn" is applied to all experts.
    """

    # One config per expert; entry `i` configures the layers built for `xs[i]`.
    configs: tuple[Config, ...]

    # Dropout rate; when falsy, dropout is replaced by an identity function.
    dropout: float = 0.0
    # Passed as the second positional argument to `nn.Dropout`.
    dropout_bdims: tuple[int, ...] = ()

    @nn.compact
    def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, deterministic=True):  # noqa: FBT002
        """Runs attention and feed-forward sub-layers with (optionally gated) residual connections.

        Args:
            xs: Sequence with one entry per expert; an entry may be None, in which case that expert is skipped.
            kv_cache: Forwarded to the attention module; the value it returns is passed back to the caller.
            positions: Forwarded to the attention module.
            attn_mask: Forwarded to the attention module.
            adarms_cond: Sequence indexed per expert; entry `i` is passed to that expert's `RMSNorm` modules.
            deterministic: Forwarded to dropout.

        Returns:
            A tuple `(xs, kv_cache)` where `xs` has the same layout as the input (None entries stay None).
        """
        xs = sharding.activation_sharding_constraint(xs)
        # Identity fallback keeps the call sites below uniform when dropout is disabled.
        drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x

        attn = Attention(configs=self.configs, name="attn")

        # Attention sub-layer: normalize each present expert input and remember its gate.
        pre_attn = []
        gates = []
        for i, x in enumerate(xs):
            if x is not None:
                x, gate = RMSNorm(name=_name("pre_attention_norm", i))(x, adarms_cond[i])  # noqa: PLW2901
            pre_attn.append(x)
            # `gate` is only read when `x` is present, so a missing expert always records None.
            gates.append(gate if x is not None else None)

        pre_attn = sharding.activation_sharding_constraint(pre_attn)
        post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache)
        post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
        post_attn = sharding.activation_sharding_constraint(post_attn)
        # Residual connection; `_gated_residual` scales the update by the gate when one is present.
        xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, post_attn, gates, strict=True)]
        xs = sharding.activation_sharding_constraint(xs)

        # Feed-forward sub-layer: per-expert norm followed by a per-expert (optionally LoRA) MLP.
        out = []
        gates = []
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is not None:
                x, gate = RMSNorm(name=_name("pre_ffw_norm", i))(x, adarms_cond[i])  # noqa: PLW2901
                x = lora.FeedForward(  # noqa: PLW2901
                    features=config.width,
                    hidden_dim=config.mlp_dim,
                    name=_name("mlp", i),
                    lora_config=config.lora_configs.get("ffn"),
                )(x)
            out.append(x)
            gates.append(gate if x is not None else None)

        out = sharding.activation_sharding_constraint(out)
        out = jax.tree.map(lambda x: drop(x, deterministic), out)
        xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, out, gates, strict=True)]
        xs = sharding.activation_sharding_constraint(xs)

        return xs, kv_cache
= () # Every float is dropped independently. adarms: bool = False def setup(self): # all experts must have the same depth assert all(config.depth == self.configs[0].depth for config in self.configs) self.embedder = Embedder( vocab_size=PALIGEMMA_VOCAB_SIZE, embed_dim=self.configs[0].width, # embedder for first expert only name="embedder", ) block_cls = nn.remat( Block, prevent_cse=False, static_argnums=(5,), # 0=self, 6=deterministic policy=jax.checkpoint_policies.nothing_saveable, ) self.layers = nn.scan( block_cls, variable_axes={"params": 0}, split_rngs={"params": True, "dropout": True}, in_axes=( 0, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, ), # 0=kv_cache, 1=positions, 2=mask, 3=adarms_cond, 4=deterministic length=self.configs[0].depth, )( configs=self.configs, dropout=self.dropout, dropout_bdims=self.dropout_bdims, ) self.final_norms = [RMSNorm(name=_name("final_norm", i)) for i in range(len(self.configs))] @at.typecheck def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array, "b t d"]: return self.embedder.encode(tokens).astype(self.embed_dtype) @at.typecheck def __call__( self, # list of token arrays, one for each expert, or None if that expert should not be run embedded: Sequence[at.Float[at.Array, "b _t _d"] | None], positions: at.Int[at.Array, "b t"], mask: at.Bool[at.Array, "b t s"], adarms_cond: Sequence[at.Float[at.Array, "b _d"] | None] | None = None, *, kv_cache: KVCache | None = None, deterministic: bool = True, ) -> tuple[Sequence[at.Float[at.Array, "b _t _d"] | None], KVCache]: embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded) mask = jnp.asarray(mask)[:, None, :, :] if adarms_cond is None: adarms_cond = [None] * len(self.configs) embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, adarms_cond, deterministic) assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None) return [ f(e, a)[0] if e is not None else e for f, e, a in zip(self.final_norms, 
def _apply_rope(x, *, positions, max_wavelength=10_000):
    """Applies RoPE positions [B, L] to x [B, L, H, D].

    The rotation is computed in float32 and the result is cast back to the dtype of `x`.
    """
    head_dim = x.shape[-1]
    # One frequency per rotated pair of features.
    exponents = (2.0 / head_dim) * jnp.arange(head_dim // 2, dtype=jnp.float32)
    timescale = max_wavelength**exponents

    # Insert a singleton axis so the angles broadcast across the head axis of `x`.
    angles = (positions[..., None] / timescale[None, None, :])[..., None, :]
    assert angles.dtype == jnp.float32
    # angles.shape = [...,L,1,d=D/2]
    sin = jnp.sin(angles)
    cos = jnp.cos(angles)

    first_half, second_half = jnp.split(x, 2, axis=-1)
    rotated_first = first_half * cos - second_half * sin
    rotated_second = second_half * cos + first_half * sin
    rotated = jnp.concatenate([rotated_first, rotated_second], axis=-1)
    assert rotated.dtype == jnp.float32
    # The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache
    # dtype when in inference mode (but not in training mode). Based on the original DeepMind impl, as well as the
    # widely-used transformers impl, it is ok to always downcast back to the input dtype here.
    return rotated.astype(x.dtype)
def _gated_residual(x, y, gate):
    """Adds the update `y` to the residual stream `x`, scaling `y` by `gate` when a gate is given.

    `x` and `y` must either both be None (the result is then None) or both be present.
    """
    assert (x is None) == (y is None)
    if x is None:
        return None
    # Without a gate this is a plain residual connection.
    update = y if gate is None else y * gate
    return x + update
def get_config(variant):
    """Returns config for specified gemma variant.

    "gemma_2b_lora" uses the same settings as "gemma_2b" plus a "lora_configs" entry. Any other variant
    raises ValueError.
    """
    if variant not in ("gemma_2b", "gemma_2b_lora"):
        raise ValueError(f"Unknown variant: {variant}")

    # Settings common to both supported variants.
    settings = {
        "variant": variant,
        "width": 2048,
        "depth": 18,
        "mlp_dim": 16_384,
        "num_heads": 8,
        "num_kv_heads": 1,
        "head_dim": 256,
        "norm_eps": 1e-6,
        "vocab_size": 257_152,
        "scan": True,
        "remat_policy": "nothing_saveable",
    }
    if variant == "gemma_2b_lora":
        # The LoRA variant adds adapters for the attention and feed-forward weights.
        settings["lora_configs"] = {
            "attn": lora.LoRAConfig(rank=16, alpha=16.0),
            "ffn": lora.LoRAConfig(rank=16, alpha=16.0),
        }
    return ml_collections.ConfigDict(settings)
@at.typecheck
class RMSNorm(nn.Module):
    """RMS normalization over the last axis with a learned, zero-initialized scale."""

    @nn.compact
    def __call__(self, x):
        input_dtype = x.dtype  # original dtype, could be half-precision
        scale = self.param("scale", nn.initializers.zeros_init(), x.shape[-1])
        # The mean of squares is computed in float32.
        mean_square = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
        inverse_rms = jnp.reciprocal(jnp.sqrt(mean_square + 1e-06))
        normalized = jnp.asarray(x * inverse_rms)
        # Scale by (1 + learned parameter), so a zero-initialized scale is the identity.
        scaled = normalized * (1 + scale)
        return scaled.astype(input_dtype)  # return in original dtype
self.head_dim), name="kv_einsum", init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)), lora_config=self.lora_config, ) self.attn_vec_einsum = lora.Einsum( shape=(self.num_heads, self.head_dim, self.features), name="attn_vec_einsum", init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), lora_config=self.lora_config, ) def _init_cache(self, k, v, cache_size): """Initialize KV cache""" prefill_len = k.shape[1] pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0)) cache_dtype = self.cache_dtype or k.dtype k_cache = jnp.pad(k.astype(cache_dtype), pad_width) v_cache = jnp.pad(v.astype(cache_dtype), pad_width) idx = jnp.zeros((k.shape[0],), dtype=jnp.int32) + prefill_len return idx, k_cache, v_cache def _update_cache(self, k, v, idx, k_cache, v_cache): """Update KV cache with new values""" assert k.shape[1] == 1, "Only support kv-cache updates of length 1" indices = (0, idx[0], 0, 0) cache_dtype = self.cache_dtype or k.dtype k_new = jax.lax.dynamic_update_slice(k_cache, k.astype(cache_dtype), indices) v_new = jax.lax.dynamic_update_slice(v_cache, v.astype(cache_dtype), indices) idx_new = idx + 1 return idx_new, k_new, v_new @nn.compact def __call__(self, x, positions, attn_mask, kv_cache, decode, deterministic=True): # noqa: FBT002 dtype = x.dtype # original dtype, could be half-precision if self.num_kv_heads == self.num_heads: q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x) else: q = self.q_einsum("BTD,NDH->BTNH", x) k, v = self.kv_einsum("BSD,2KDH->2BSKH", x) q = _apply_rope(q, positions=positions) # promotes to float32 q *= self.head_dim**-0.5 k = _apply_rope(k, positions=positions) # promotes to float32 if kv_cache is None: idx, k_cache, v_cache = self._init_cache(k, v, attn_mask.shape[-1]) else: idx, k_cache, v_cache = kv_cache idx, k_cache, v_cache = self._update_cache(k, v, idx, k_cache, v_cache) k, v = k_cache, v_cache kv_cache = (idx, k_cache, v_cache) q = einops.rearrange(q, "B T (K G) H -> B T 
@at.typecheck
class Block(nn.Module):
    """Transformer block: pre-norm attention and pre-norm feed-forward, each followed by dropout and a residual add."""

    num_heads: int
    num_kv_heads: int
    embed_dim: int
    head_dim: int
    hidden_dim: int

    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()
    cache_dtype: str | None = None
    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)

    def setup(self):
        self.pre_attention_norm = RMSNorm()
        self.attn = Attention(
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            features=self.embed_dim,
            head_dim=self.head_dim,
            cache_dtype=self.cache_dtype,
            lora_config=self.lora_configs.get("attn"),
        )
        self.pre_ffw_norm = RMSNorm()
        self.mlp = lora.FeedForward(
            features=self.embed_dim,
            hidden_dim=self.hidden_dim,
            name="mlp",
            lora_config=self.lora_configs.get("ffn"),
        )
        # With a falsy dropout rate, fall back to an identity with the same call signature.
        self.drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else (lambda x, _: x)

    def __call__(self, x, kv_cache, positions, attn_mask, decode, deterministic=True):  # noqa: FBT002
        """Applies the block to `x` and returns `(outputs, kv_cache)` with the cache returned by attention."""
        x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))

        # Attention sub-layer.
        normed_inputs = self.pre_attention_norm(x)
        attended, kv_cache = self.attn(normed_inputs, positions, attn_mask, kv_cache, decode, deterministic)
        after_attention = self.drop(attended, deterministic) + x

        # Feed-forward sub-layer.
        mlp_out = self.mlp(self.pre_ffw_norm(after_attention))
        mlp_out = self.drop(mlp_out, deterministic)
        return after_attention + mlp_out, kv_cache
""" out = {} embedder = Embedder(vocab_size=self.vocab_size, embed_dim=self.width, name="embedder") if pre_logits is not None: x = out["pre_logits"] = pre_logits logits = out["logits"] = embedder.decode(x) return logits, out x = [] if embedded_prefix is not None: x.append(embedded_prefix) if tokens is not None: x.append(embedder.encode(tokens)) x = jnp.concatenate(x, axis=-2) x = x.astype(self.embed_dtype) batch_size, seq_len, width = x.shape if embed_only: return x if decode: assert positions is not None and mask is not None, ( # noqa: PT018 "Must explicitly pass positions and mask for decoding." ) if positions is None: positions = jnp.arange(seq_len).astype(jnp.int32)[None, :] assert positions.shape[1] == x.shape[1], (positions.shape, x.shape) if mask is None: mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len])) if mask.ndim == 3: mask = mask[:, None, :, :] cache_size = max(seq_len, mask.shape[-1]) assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape if self.remat_policy == "none": block_cls = Block else: block_cls = nn.remat( Block, prevent_cse=not self.scan, static_argnums=(5, 6), # 0=self, 5=decode, 6=deterministic policy=getattr(jax.checkpoint_policies, self.remat_policy), ) block_kw = { "num_heads": self.num_heads, "head_dim": self.head_dim, "num_kv_heads": self.num_kv_heads, "embed_dim": width, "hidden_dim": self.mlp_dim, "dropout": self.dropout, "dropout_bdims": self.dropout_bdims, "cache_dtype": self.cache_dtype, "lora_configs": self.lora_configs, } layers = self.scope.push("layers") blocks = [ nn.scan( block_cls, variable_axes={"params": 0}, split_rngs={"params": True, "dropout": True}, in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), # 0=kv_cache, 1=positions, 2=mask length=self.depth, )(parent=layers, **block_kw) ] for block in blocks: x, kv_cache = block(x, kv_cache, positions, mask, decode, deterministic) assert x.dtype == jnp.dtype(self.embed_dtype) # Sanity check. 
class Einsum(nn.Module):
    """Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum.

    When `lora_config` is set, two extra parameters ("lora_a" and "lora_b") are created and the
    result of the factorized einsum through them, multiplied by `lora_config.scaling_value`,
    is added to the output of the base einsum.
    """

    # Shape of the weight.
    shape: tuple[int, ...]
    # Initialization function for the weight.
    init_fn: nn.initializers.Initializer = nn.initializers.zeros
    # If not None, apply LoRA to the weight.
    lora_config: LoRAConfig | None = None

    def setup(self):
        self.w = self.param("w", self.init_fn, self.shape)

        if config := self.lora_config:
            # Setup LoRA parameters: each factor is the weight shape with one of the two LoRA axes
            # replaced by the rank.
            shape_a, shape_b = list(self.shape), list(self.shape)
            shape_a[config.axes[1]] = config.rank
            shape_b[config.axes[0]] = config.rank
            self.w_a = self.param("lora_a", config.init_fn, shape_a)
            self.w_b = self.param("lora_b", config.init_fn, shape_b)

    @nn.compact
    def __call__(self, eqn: str, x):
        """Computes `einsum(eqn, x, w)`, plus the scaled LoRA term when LoRA is enabled.

        Args:
            eqn: Two-operand einsum equation of the form "lhs,rhs->out". With LoRA enabled it must
                not contain the LoRA label.
            x: Left operand. All weights are cast to its dtype before use.

        Returns:
            The einsum result.
        """
        dtype = x.dtype  # original dtype, could be half-precision
        result = jnp.einsum(eqn, x, self.w.astype(dtype))

        if config := self.lora_config:
            eqn_a, eqn_b = self._make_lora_eqns(eqn)
            lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))
            lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))
            result = result + lora * config.scaling_value

        return result

    def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
        """Derives the two einsum equations that route `eqn` through the LoRA factors.

        Raises:
            ValueError: If the configured LoRA label already appears in `eqn`, or if `eqn` does not
                have the "lhs,rhs->out" form.
        """
        assert self.lora_config is not None
        label = self.lora_config.label
        # Check the configured label (not a hard-coded "L"): a clash would silently produce a wrong
        # equation, while an unrelated "L" in `eqn` is harmless when a different label is configured.
        if label in eqn:
            raise ValueError(f"{label} already in eqn: {eqn}")
        if not (m := re.match("(.*),(.*)->(.*)", eqn)):
            raise ValueError(f"Unsupported einsum eqn: {eqn}")
        lhs, rhs, out = m.groups()

        # Einsum letters of the two weight axes that LoRA factorizes.
        a_label, b_label = (rhs[x] for x in self.lora_config.axes)

        # First factor: the second LoRA axis is replaced by the rank axis.
        a_rhs = rhs.replace(b_label, label)
        a_out = out.replace(b_label, label)
        eqn_a = f"{lhs},{a_rhs}->{a_out}"

        # Second factor: contracts the rank axis and restores the original output layout.
        b_rhs = rhs.replace(a_label, label)
        eqn_b = f"{a_out},{b_rhs}->{out}"

        return eqn_a, eqn_b
def test_lora_einsum_params_shape():
    """Checks which parameters `lora.Einsum` creates, and their shapes, with and without LoRA."""
    weight_shape = (3, 8, 32, 4)  # (3KDH)
    rng = jax.random.key(0)
    inputs = jax.random.normal(rng, (8, 64, 32))  # (BSD)
    equation = "BSD,3KDH->3BSKH"

    def init_params(module):
        return module.init(rng, equation, inputs)["params"]

    # Ensure that lora parameters are not initialized when LoRA is not used.
    plain_params = init_params(lora.Einsum(weight_shape))
    for lora_name in ("lora_a", "lora_b"):
        assert lora_name not in plain_params

    # Check that default axes work.
    default_axes_params = init_params(lora.Einsum(weight_shape, lora_config=lora.LoRAConfig(rank=2)))
    assert default_axes_params["lora_a"].shape == (3, 8, 32, 2)
    assert default_axes_params["lora_b"].shape == (3, 8, 2, 4)

    # Check that user provided axes work.
    custom_axes_params = init_params(lora.Einsum(weight_shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2))))
    assert custom_axes_params["lora_a"].shape == (3, 8, 2, 4)
    assert custom_axes_params["lora_b"].shape == (3, 2, 32, 4)
def test_lora_ffn_params_shape():
    """Checks the parameter shapes of `lora.FeedForward` with and without a LoRA config."""
    rng = jax.random.key(0)
    inputs = jax.random.normal(rng, (2, 8))

    base_shapes = {
        "gating_einsum": (2, 8, 32),
        "linear": (32, 8),
    }
    lora_shapes = {
        "gating_einsum_lora_a": (2, 8, 2),
        "gating_einsum_lora_b": (2, 2, 32),
        "linear_lora_a": (32, 2),
        "linear_lora_b": (2, 8),
    }

    # Without LoRA only the base weights are checked.
    plain = lora.FeedForward(features=8, hidden_dim=32)
    plain_params = plain.init(rng, inputs)["params"]
    for name, expected in base_shapes.items():
        assert plain_params[name].shape == expected

    # With LoRA the base weights keep their shapes and the low-rank factors are added.
    adapted = lora.FeedForward(features=8, hidden_dim=32, lora_config=lora.LoRAConfig(rank=2))
    adapted_params = adapted.init(rng, inputs)["params"]
    for name, expected in {**base_shapes, **lora_shapes}.items():
        assert adapted_params[name].shape == expected
@at.typecheck
@struct.dataclass
class Observation(Generic[ArrayT]):
    """Holds observations, i.e., inputs to the model.

    See `Observation.from_dict` to see the expected dictionary form. This is the format
    that should be produced by the data transforms.
    """

    # Images, in [-1, 1] float32.
    images: dict[str, at.Float[ArrayT, "*b h w c"]]
    # Image masks, with same keys as images.
    image_masks: dict[str, at.Bool[ArrayT, "*b"]]
    # Low-dimensional robot state.
    state: at.Float[ArrayT, "*b s"]

    # Tokenized prompt.
    tokenized_prompt: at.Int[ArrayT, "*b l"] | None = None
    # Tokenized prompt mask.
    tokenized_prompt_mask: at.Bool[ArrayT, "*b l"] | None = None

    # pi0-fast model specific fields.

    # Token auto-regressive mask (for FAST autoregressive model).
    token_ar_mask: at.Int[ArrayT, "*b l"] | None = None
    # Token loss mask (for FAST autoregressive model).
    token_loss_mask: at.Bool[ArrayT, "*b l"] | None = None

    @classmethod
    def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]":
        """Builds an Observation from the unstructured (nested dict) form produced by the data transforms.

        uint8 images are converted to float32 in [-1, 1]. The conversion is done on a new dict, so the
        caller's `data` is left unmodified.

        Args:
            data: Nested dict with the keys "image", "image_mask" and "state", and optionally
                "tokenized_prompt", "tokenized_prompt_mask", "token_ar_mask" and "token_loss_mask".

        Returns:
            The structured Observation.

        Raises:
            ValueError: If only one of "tokenized_prompt" / "tokenized_prompt_mask" is present.
        """
        # Ensure that tokenized_prompt and tokenized_prompt_mask are provided together.
        if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data):
            raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.")

        # If images are uint8, convert them to [-1, 1] float32. Results go into a fresh dict so that
        # the input is not mutated behind the caller's back.
        images = {}
        for key, image in data["image"].items():
            if image.dtype == np.uint8:
                converted = image.astype(np.float32) / 255.0 * 2.0 - 1.0
            elif image.dtype == torch.uint8:
                # The torch path additionally permutes the axes to (0, 3, 1, 2), which requires a
                # 4-dimensional tensor.
                converted = image.to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0
            else:
                converted = image
            images[key] = converted

        return cls(
            images=images,
            image_masks=data["image_mask"],
            state=data["state"],
            tokenized_prompt=data.get("tokenized_prompt"),
            tokenized_prompt_mask=data.get("tokenized_prompt_mask"),
            token_ar_mask=data.get("token_ar_mask"),
            token_loss_mask=data.get("token_loss_mask"),
        )

    def to_dict(self) -> at.PyTree[ArrayT]:
        """Convert the Observation to a nested dict.

        The "images" and "image_masks" fields are renamed to "image" and "image_mask", matching the
        keys read by `from_dict`.
        """
        result = dataclasses.asdict(self)
        result["image"] = result.pop("images")
        result["image_mask"] = result.pop("image_masks")
        return result
def preprocess_observation(
    rng: at.KeyArrayLike | None,
    observation: Observation,
    *,
    train: bool = False,
    image_keys: Sequence[str] = IMAGE_KEYS,
    image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
) -> Observation:
    """Prepares an observation for the model.

    Images listed in `image_keys` are resized with padding when their resolution differs from
    `image_resolution`, augmented when `train=True`, and given an all-True mask when the observation
    has no mask for them. All other fields are passed through unchanged.

    Raises:
        ValueError: If any of `image_keys` is missing from `observation.images`.
    """
    if not set(image_keys) <= set(observation.images):
        raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")

    batch_shape = observation.state.shape[:-1]

    processed_images = {}
    for key in image_keys:
        image = observation.images[key]
        if image.shape[1:3] != image_resolution:
            logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
            image = image_tools.resize_with_pad(image, *image_resolution)

        if train:
            # Convert from [-1, 1] to [0, 1] for augmax.
            image = image / 2.0 + 0.5

            # Crop/resize/rotate is applied to every camera whose key does not contain "wrist".
            geometric = []
            if "wrist" not in key:
                height, width = image.shape[1:3]
                geometric = [
                    augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),
                    augmax.Resize(width, height),
                    augmax.Rotate((-5, 5)),
                ]
            color = [augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5)]
            augment = augmax.Chain(*geometric, *color)

            # One rng per element along the leading axis.
            per_example_rngs = jax.random.split(rng, image.shape[0])
            image = jax.vmap(augment)(per_example_rngs, image)

            # Back to [-1, 1].
            image = image * 2.0 - 1.0

        processed_images[key] = image

    # Use the provided mask where there is one; otherwise do not mask.
    masks = {
        key: (
            jnp.asarray(observation.image_masks[key])
            if key in observation.image_masks
            else jnp.ones(batch_shape, dtype=jnp.bool)
        )
        for key in processed_images
    }

    return Observation(
        images=processed_images,
        image_masks=masks,
        state=observation.state,
        tokenized_prompt=observation.tokenized_prompt,
        tokenized_prompt_mask=observation.tokenized_prompt_mask,
        token_ar_mask=observation.token_ar_mask,
        token_loss_mask=observation.token_loss_mask,
    )
def restore_params(
    params_path: pathlib.Path | str,
    *,
    restore_type: type[np.ndarray] | type[jax.Array] = jax.Array,
    dtype: jnp.dtype | None = None,
    sharding: jax.sharding.Sharding | None = None,
) -> at.Params:
    """Load an unstructured params PyTree from a checkpoint directory.

    Handles both checkpoints written by `save_state` during openpi training (see `training/checkpoints.py`)
    and the pre-trained checkpoints released for openpi.

    Args:
        params_path: Path to the checkpoint directory. Paths starting with "gs://" are used as given;
            any other path is resolved to an absolute local path.
        restore_type: Array type of the restored leaves; pass `np.ndarray` to get numpy arrays.
        dtype: Dtype to restore every param as. When None, the dtype stored in the checkpoint is kept.
        sharding: Sharding for the restored params. When None and `restore_type` is `jax.Array`, the
            params are replicated across all devices.

    Returns:
        The restored params as a nested "pure dict" (no trailing "value" keys).
    """
    if not str(params_path).startswith("gs://"):
        params_path = pathlib.Path(params_path).resolve()

    if sharding is None and restore_type is jax.Array:
        # An empty PartitionSpec over a mesh of all devices means full replication.
        device_mesh = jax.sharding.Mesh(jax.devices(), ("x",))
        sharding = jax.sharding.NamedSharding(device_mesh, jax.sharding.PartitionSpec())

    def _leaf_restore_args(_):
        return ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype)

    with ocp.PyTreeCheckpointer() as checkpointer:
        # Only the "params" subtree of the checkpoint is requested.
        target = {"params": checkpointer.metadata(params_path)["params"]}
        restored = checkpointer.restore(
            params_path,
            ocp.args.PyTreeRestore(item=target, restore_args=jax.tree.map(_leaf_restore_args, target)),
        )

    # Checkpoints saved with `save_state` carry a trailing "value" key on every path (added by `nnx.State`).
    # Strip it so that the result is always what NNX calls a "pure dict".
    flat = traverse_util.flatten_dict(restored["params"])
    if all(path[-1] == "value" for path in flat):
        flat = {path[:-1]: leaf for path, leaf in flat.items()}
    return traverse_util.unflatten_dict(flat)
config.create(key) batch_size = 2 obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) assert loss.shape == (batch_size, config.action_horizon) actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10) assert actions.shape == (batch_size, model.action_horizon, model.action_dim) def test_pi0_fast_model(): key = jax.random.key(0) config = pi0_fast.Pi0FASTConfig() model = config.create(key) batch_size = 2 obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) assert loss.shape == (batch_size,) actions = nnx_utils.module_jit(model.sample_actions)(key, obs) assert actions.shape == (batch_size, 256) def test_pi0_fast_lora_model(): key = jax.random.key(0) config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora") model = config.create(key) batch_size = 2 obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act) assert loss.shape == (batch_size,) actions = nnx_utils.module_jit(model.sample_actions)(key, obs) assert actions.shape == (batch_size, 256) lora_filter = nnx_utils.PathRegex(".*lora.*") model_state = nnx.state(model) lora_state_elems = list(model_state.filter(lora_filter)) assert len(lora_state_elems) > 0 @pytest.mark.manual def test_model_restore(): key = jax.random.key(0) config = pi0_config.Pi0Config() batch_size = 2 obs, act = config.fake_obs(batch_size), config.fake_act(batch_size) model = config.load( _model.restore_params(download.maybe_download("gs://openpi-assets/checkpoints/pi0_base/params")) ) loss = model.compute_loss(key, obs, act) assert loss.shape == (batch_size, config.action_horizon) actions = model.sample_actions(key, obs, num_steps=10) assert actions.shape == (batch_size, model.action_horizon, model.action_dim) ================================================ FILE: src/openpi/models/pi0.py 
def make_attn_mask(input_mask, mask_ar):
    """Build a 2D attention mask per batch element from a padding mask and an autoregressive mask.

    Adapted from big_vision. A token may attend to every valid token whose cumulative `mask_ar` is
    smaller than or equal to its own. This lets `mask_ar` (bool[?B, N]) express several attention
    patterns, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens attend to each other and the last 3
          tokens are causal. The first entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a block attend to all
          previous blocks and to all tokens of their own block.

    Args:
      input_mask: bool[B, N], true for real input tokens and false for padding.
      mask_ar: bool[?B, N], true where previous tokens cannot depend on the token and false where the
        token shares the attention mask of the previous token.

    Returns:
      bool[B, N, N] where entry [b, i, j] tells whether query token i may attend to key token j.
    """
    ar_flags = jnp.broadcast_to(mask_ar, input_mask.shape)
    # Tokens with the same cumulative count belong to the same attention block.
    block_ids = jnp.cumsum(ar_flags, axis=1)
    key_blocks = block_ids[:, None, :]
    query_blocks = block_ids[:, :, None]
    may_attend = key_blocks <= query_blocks
    # Both the query and the key have to be real (non-padding) tokens.
    both_valid = input_mask[:, None, :] * input_mask[:, :, None]
    return jnp.logical_and(may_attend, both_valid)
llm = nnx_bridge.ToNNX( _gemma.Module( configs=[paligemma_config, action_expert_config], embed_dtype=config.dtype, adarms=config.pi05, ) ) llm.lazy_init(rngs=rngs, method="init", use_adarms=[False, True] if config.pi05 else [False, False]) img = nnx_bridge.ToNNX( _siglip.Module( num_classes=paligemma_config.width, variant="So400m/14", pool_type="none", scan=True, dtype_mm=config.dtype, ) ) img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs) self.PaliGemma = nnx.Dict(llm=llm, img=img) self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs) if config.pi05: self.time_mlp_in = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs) self.time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs) else: self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs) self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs) self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs) self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs) # This attribute gets automatically set by model.train() and model.eval(). 
@at.typecheck
def embed_suffix(
    self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"]
) -> tuple[
    at.Float[at.Array, "b s emb"],
    at.Bool[at.Array, "b s"],
    at.Bool[at.Array, " s"],
    at.Float[at.Array, "b emb"] | None,
]:
    """Embed the suffix of the sequence: an optional state token followed by the noisy action tokens.

    Returns a tuple `(tokens, input_mask, ar_mask, adarms_cond)`. `adarms_cond` is the processed
    timestep embedding when `self.pi05` is set and None otherwise.
    """
    token_chunks = []
    mask_chunks = []
    ar_flags = []

    uses_adarms = self.pi05
    if not uses_adarms:
        # Without pi05 the state enters as one extra token in front of the actions.
        state_embedding = self.state_proj(obs.state)
        token_chunks.append(state_embedding[:, None, :])
        mask_chunks.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))
        # image/language inputs do not attend to state or actions
        ar_flags.append(True)

    action_tokens = self.action_in_proj(noisy_actions)
    # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
    time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)

    if uses_adarms:
        # time MLP (for adaRMS); the action tokens themselves are passed on unchanged
        adarms_cond = nnx.swish(self.time_mlp_out(nnx.swish(self.time_mlp_in(time_emb))))
        action_expert_tokens = action_tokens
    else:
        # mix timestep + action information using an MLP (no adaRMS)
        adarms_cond = None
        time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)
        fused = jnp.concatenate([action_tokens, time_tokens], axis=-1)
        action_expert_tokens = self.action_time_mlp_out(nnx.swish(self.action_time_mlp_in(fused)))

    token_chunks.append(action_expert_tokens)
    mask_chunks.append(jnp.ones(action_expert_tokens.shape[:2], dtype=jnp.bool_))
    # image/language/state inputs do not attend to action tokens
    ar_flags.extend([True] + [False] * (self.action_horizon - 1))

    return (
        jnp.concatenate(token_chunks, axis=1),
        jnp.concatenate(mask_chunks, axis=1),
        jnp.array(ar_flags),
        adarms_cond,
    )
@override
def sample_actions(
    self,
    rng: at.KeyArrayLike,
    observation: _model.Observation,
    *,
    num_steps: int | at.Int[at.Array, ""] = 10,
    noise: at.Float[at.Array, "b ah ad"] | None = None,
) -> _model.Actions:
    """Sample actions by integrating the predicted velocity from t=1 (noise) down to t=0.

    The prefix (embedded by `embed_prefix`) is run through the LLM once to obtain a KV cache. Each
    integration step then re-embeds only the suffix for the current `x_t` and time, and applies the
    update `x_t + dt * v_t` with `dt = -1 / num_steps`.

    Args:
        rng: Random key; only used to draw the initial noise when `noise` is None.
        observation: The observation to condition on. It is preprocessed with `train=False`.
        num_steps: Number of integration steps.
        noise: Optional initial noise of shape (batch, action_horizon, action_dim).

    Returns:
        The value of `x_t` once the loop has terminated.
    """
    observation = _model.preprocess_observation(None, observation, train=False)
    # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
    # distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
    dt = -1.0 / num_steps
    batch_size = observation.state.shape[0]
    if noise is None:
        noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))

    # first fill KV cache with a forward pass of the prefix
    prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
    prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
    positions = jnp.cumsum(prefix_mask, axis=1) - 1
    _, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)

    def step(carry):
        # carry is (x_t, time): the current sample and the current (scalar) time.
        x_t, time = carry
        suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(
            observation, x_t, jnp.broadcast_to(time, batch_size)
        )
        # `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to each
        # other
        suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
        # `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to the
        # prefix tokens (this local name shadows the prefix-to-prefix mask of the enclosing scope)
        prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
        # `full_attn_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which
        # generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values)
        full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
        assert full_attn_mask.shape == (
            batch_size,
            suffix_tokens.shape[1],
            prefix_tokens.shape[1] + suffix_tokens.shape[1],
        )
        # `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens
        positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1

        # Only the suffix is fed to the model; the prefix comes from the KV cache.
        (prefix_out, suffix_out), _ = self.PaliGemma.llm(
            [None, suffix_tokens],
            mask=full_attn_mask,
            positions=positions,
            kv_cache=kv_cache,
            adarms_cond=[None, adarms_cond],
        )

        assert prefix_out is None
        # Only the last `action_horizon` suffix outputs are projected to velocities.
        v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])

        return x_t + dt * v_t, time + dt

    def cond(carry):
        x_t, time = carry
        # robust to floating-point error
        # (dt is negative, so -dt / 2 is half a step above zero)
        return time >= -dt / 2

    # Start from pure noise at time 1.0 and iterate until `cond` returns False.
    x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
    return x_0
action_dim: int = 32 action_horizon: int = 50 max_token_len: int = None # type: ignore # Pi05 has two differences from Pi0: # - the state input is part of the discrete language tokens rather than a continuous input that is part of the suffix # - the action expert uses adaRMSNorm to inject the flow matching timestep pi05: bool = False # This config option is not used directly by the model, but it is read by the ModelTransformFactory. discrete_state_input: bool = None # type: ignore pytorch_compile_mode: str | None = "max-autotune" def __post_init__(self): if self.max_token_len is None: object.__setattr__(self, "max_token_len", 200 if self.pi05 else 48) if self.discrete_state_input is None: object.__setattr__(self, "discrete_state_input", self.pi05) if self.pytorch_compile_mode is not None: assert self.pytorch_compile_mode in [ "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs", ] @property @override def model_type(self) -> _model.ModelType: if self.pi05: return _model.ModelType.PI05 return _model.ModelType.PI0 @override def create(self, rng: at.KeyArrayLike) -> "Pi0": from openpi.models.pi0 import Pi0 return Pi0(self, rngs=nnx.Rngs(rng)) @override def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]: image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32) image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_) with at.disable_typechecking(): observation_spec = _model.Observation( images={ "base_0_rgb": image_spec, "left_wrist_0_rgb": image_spec, "right_wrist_0_rgb": image_spec, }, image_masks={ "base_0_rgb": image_mask_spec, "left_wrist_0_rgb": image_mask_spec, "right_wrist_0_rgb": image_mask_spec, }, state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32), tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32), tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool), ) action_spec = 
def get_freeze_filter(self) -> nnx.filterlib.Filter:
    """Build the filter selecting the params to freeze, based on which variants use LoRA.

    Returns `nnx.Nothing` when neither `paligemma_variant` nor `action_expert_variant` contains
    "lora". Otherwise returns the conjunction of the selected `llm` filters with a filter that
    excludes every path matching ".*lora.*".
    """
    paligemma_lora = "lora" in self.paligemma_variant
    action_expert_lora = "lora" in self.action_expert_variant
    if not (paligemma_lora or action_expert_lora):
        return nnx.Nothing

    llm_params = nnx_utils.PathRegex(".*llm.*")
    action_expert_params = nnx_utils.PathRegex(".*llm.*_1.*")

    if paligemma_lora and action_expert_lora:
        selected = [llm_params]
    elif paligemma_lora:
        # If only freeze gemma params, exclude action expert params.
        selected = [llm_params, nnx.Not(action_expert_params)]
    else:
        selected = [action_expert_params]

    # If any lora is used, exclude all lora params.
    selected.append(nnx.Not(nnx_utils.PathRegex(".*lora.*")))
    return nnx.All(*selected)
def put_along_last_axis(arr, indices, values):
    """Write `values` into `arr` at `indices` along the last axis.

    Stand-in for `np.put_along_axis(..., axis=-1)`, which jax is missing. The three arguments must
    have the same number of dimensions. A new array is returned; `arr` is not modified.
    """
    ranks = (arr.ndim, indices.ndim, values.ndim)
    assert arr.ndim == indices.ndim == values.ndim, ranks

    scatter = "...i,...in->...n"
    # One row per index, with a single 1 at the target position along the last axis of `arr`.
    selector = jax.nn.one_hot(indices, arr.shape[-1], dtype=values.dtype)
    # Non-zero wherever at least one index points at that position.
    written = jnp.einsum(scatter, jnp.ones(values.shape, jnp.int32), selector)
    scattered = jnp.einsum(scatter, values, selector)
    return jnp.where(written, scattered, arr)
def get_freeze_filter(self) -> nnx.filterlib.Filter:
    """Build the filter selecting the params to freeze, based on the model config.

    Returns `nnx.Nothing` unless `paligemma_variant` contains "lora"; in that case the filter matches
    every path matching ".*llm.*" that does not also match ".*lora.*".
    """
    if "lora" not in self.paligemma_variant:
        return nnx.Nothing

    llm_params = nnx_utils.PathRegex(".*llm.*")
    lora_params = nnx_utils.PathRegex(".*lora.*")
    return nnx.All(llm_params, nnx.Not(lora_params))
@at.typecheck
def embed_inputs(
    self, obs: _model.Observation
) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Int[at.Array, "b s"]]:
    """Embed all images and the tokenized prompt into one sequence.

    Returns a tuple `(token_embeddings, input_mask, ar_mask)`, each concatenated along the sequence
    axis in the order: images (in the iteration order of `obs.images`), then the tokenized prompt.
    """
    embedding_chunks = []
    mask_chunks = []
    ar_chunks = []

    # embed images
    for name, image in obs.images.items():
        image_embeddings, _ = self.PaliGemma.img(image, train=False)
        # Every token of an image shares that image's mask value.
        image_token_mask = einops.repeat(obs.image_masks[name], "b -> b s", s=image_embeddings.shape[1])

        embedding_chunks.append(image_embeddings)
        mask_chunks.append(image_token_mask)
        # image tokens attend to each other --> AR mask = 0
        ar_chunks.append(0 * image_token_mask)

    # add tokenized inputs
    assert obs.tokenized_prompt is not None, "Tokenized prompt is required"
    assert obs.tokenized_prompt_mask is not None, "Tokenized prompt mask is required"
    assert obs.token_ar_mask is not None, "Token auto-regressive mask is required"
    embedding_chunks.append(self.PaliGemma.llm(obs.tokenized_prompt, embed_only=True))
    mask_chunks.append(obs.tokenized_prompt_mask)
    ar_chunks.append(obs.token_ar_mask)

    # return embeddings, input mask, and ar mask
    token_embeddings = jnp.concatenate(embedding_chunks, axis=1)
    input_mask = jnp.concatenate(mask_chunks, axis=1)
    ar_mask = jnp.concatenate(ar_chunks, axis=1)
    return token_embeddings, input_mask, ar_mask
@override
def sample_actions(
    self,
    rng: at.KeyArrayLike,
    observation: _model.Observation,
    *,
    max_decoding_steps: int | at.Int[at.Array, ""] = 256,
    temperature: float = 0.0,
) -> _model.Actions:
    """Autoregressively decode up to `max_decoding_steps` tokens conditioned on the observation.

    The embedded inputs are right-aligned, run through the LLM once to fill the KV cache, and then
    tokens are decoded one at a time inside a `jax.lax.while_loop`. Decoding stops once every batch
    element produced `PALIGEMMA_EOS_TOKEN` in the same step, or after `max_decoding_steps` steps.

    Args:
        rng: Random key, split once per decoding step; used for sampling when `temperature > 0`.
        observation: The observation to condition on. It is preprocessed with `train=False`.
        max_decoding_steps: Maximum number of tokens to decode; also the width of the output buffer.
        temperature: Sampling temperature. Values `<= 0` select the argmax token.

    Returns:
        An array of shape (batch, max_decoding_steps) holding the decoded tokens; positions that were
        never written keep their initial value of zero.
    """
    # TODO: this is a hack to get the image keys.
    observation = _model.preprocess_observation(
        None, observation, train=False, image_keys=list(observation.images.keys())
    )

    # embed inputs
    prefix_token_embeddings, prefix_mask, prefix_ar_mask = self.embed_inputs(observation)
    prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)

    # left to right align all input token sequences
    prefix_token_embeddings, prefix_mask, prefix_attn_mask = left_to_right_align(
        prefix_token_embeddings, prefix_mask, prefix_attn_mask
    )
    # prefill_size: padded prefix length; prefill_len: number of valid prefix tokens per example;
    # prefix_start: index of the first valid token after right-alignment.
    prefill_size = prefix_token_embeddings.shape[1]
    prefill_len = jnp.sum(prefix_mask, axis=-1)
    prefix_start = prefill_size - prefill_len

    # first fill KV cache with a forward pass of the prefix
    # pad attention mask to set the size of the KV cache (prefill_size + max_decoding_steps)
    prefix_attn_mask = jnp.pad(prefix_attn_mask, ((0, 0), (0, 0), (0, max_decoding_steps)))
    prefix_positions = jnp.cumsum(prefix_mask, axis=-1) - 1
    prefix_logits, kv_cache, _ = self.PaliGemma.llm(
        embedded_prefix=prefix_token_embeddings, mask=prefix_attn_mask, positions=prefix_positions, decode=True
    )

    # prepare decoding -- final logit decodes the first token
    last_logit = prefix_logits[:, -1:]
    # Output buffer; created with jnp.zeros' default dtype since no dtype is passed.
    output_tokens = jnp.zeros((last_logit.shape[0], max_decoding_steps))

    def step(carry):
        # carry: (rng, last_logit, output_tokens, kv cache, all_eos flag, step counter)
        rng, last_logit, output_tokens, cache, _, step = carry

        # Sample token from last logit
        # Split RNG for this step
        rng, rng_step = jax.random.split(rng)
        token = jax.lax.cond(
            temperature > 0.0,
            lambda _: jax.random.categorical(rng_step, last_logit / temperature, axis=-1),
            lambda _: jnp.argmax(last_logit, axis=-1),
            operand=None,
        )
        # Write the new token into column `step` of the output buffer.
        output_tokens = put_along_last_axis(output_tokens, jnp.broadcast_to(step, (token.shape[0], 1)), token)

        # Check for early stopping --> stop if all batch elements have EOS token
        has_eos = jnp.any(token == PALIGEMMA_EOS_TOKEN, axis=-1)
        all_eos = jnp.all(has_eos)

        # Decode one step
        token_embedding = self.PaliGemma.llm(token, embed_only=True)
        positions = prefill_len[:, None] + step + 1
        # The new token may attend to cache slots from the first valid prefix token (prefix_start)
        # up to and including slot prefill_size + step.
        mask = jnp.logical_and(
            jnp.arange(prefill_size + max_decoding_steps)[None, None, :] >= prefix_start[:, None, None],
            jnp.arange(prefill_size + max_decoding_steps)[None, None, :]
            < (jnp.broadcast_to(prefill_size + step + 1, (prefix_start.shape[0], 1, 1))),
        )
        last_logit, kv_cache, _ = self.PaliGemma.llm(
            embedded_prefix=token_embedding, mask=mask, positions=positions, decode=True, kv_cache=cache
        )

        return rng, last_logit, output_tokens, kv_cache, all_eos, step + 1

    def cond(carry):
        # Continue while not every batch element has emitted EOS and the step budget is not exhausted.
        _, _, _, _, all_eos, step = carry
        return (~all_eos) & (step < max_decoding_steps)

    # Use lax.while_loop so we can jit the full decoding loop.
    _, _, output_tokens, _, _, _ = jax.lax.while_loop(
        cond, step, (rng, last_logit, output_tokens, kv_cache, False, 0)
    )
    return output_tokens
assert all(any("_1" in p for p in path) for path in state) def test_pi0_all_lora(): config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora") state = _get_frozen_state(config) # sum of gemma_lora and action_expert_lora's frozen params. assert len(state) == 17 assert all("lora" not in p for p in state) assert all("llm" in p for p in state) ================================================ FILE: src/openpi/models/siglip.py ================================================ # Copyright 2024 Big Vision Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""A refactored and simplified ViT adoptation for Pi, taken from big_vision.""" from collections.abc import Sequence import flax.linen as nn import jax import jax.numpy as jnp import numpy as np import openpi.training.sharding as sharding def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32): """Follows the MoCo v3 logic.""" y, x = jnp.mgrid[:h, :w] assert width % 4 == 0, "Width must be mult of 4 for sincos posemb" omega = jnp.arange(width // 4) / (width // 4 - 1) omega = 1.0 / (temperature**omega) y = jnp.einsum("m,d->md", y.flatten(), omega) x = jnp.einsum("m,d->md", x.flatten(), omega) pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1) return jnp.asarray(pe, dtype)[None, :, :] def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32): if typ == "learn": return self.param( name, nn.initializers.normal(stddev=1 / np.sqrt(width)), (1, np.prod(seqshape), width), dtype, ) if typ == "sincos2d": return posemb_sincos_2d(*seqshape, width, dtype=dtype) raise ValueError(f"Unknown posemb type: {typ}") class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" mlp_dim: int | None = None # Defaults to 4x input dim dropout: float = 0.0 dtype_mm: str = "float32" @nn.compact def __call__(self, x, deterministic=True): # noqa: FBT002 """Applies Transformer MlpBlock module.""" inits = { "kernel_init": nn.initializers.xavier_uniform(), "bias_init": nn.initializers.normal(stddev=1e-6), } _, _, d = x.shape # n,l,d x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x) x = nn.gelu(x) x = nn.Dropout(rate=self.dropout)(x, deterministic) return nn.Dense(d, dtype=self.dtype_mm, **inits)(x) class Encoder1DBlock(nn.Module): """Single transformer encoder block (MHSA + MLP).""" mlp_dim: int | None = None # Defaults to 4x input dim num_heads: int = 12 dropout: float = 0.0 dtype_mm: str = "float32" @nn.compact def __call__(self, x, deterministic=True): # noqa: FBT002 out = {} x = 
sharding.activation_sharding_constraint(x) y = nn.LayerNorm(dtype=self.dtype_mm)(x) y = out["sa"] = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=deterministic, dtype=self.dtype_mm, )(y, y) y = sharding.activation_sharding_constraint(y) y = nn.Dropout(rate=self.dropout)(y, deterministic) x = out["+sa"] = x + y y = nn.LayerNorm(dtype=self.dtype_mm)(x) y = out["mlp"] = MlpBlock( mlp_dim=self.mlp_dim, dropout=self.dropout, dtype_mm=self.dtype_mm, )(y, deterministic) y = sharding.activation_sharding_constraint(y) y = nn.Dropout(rate=self.dropout)(y, deterministic) x = out["+mlp"] = x + y x = sharding.activation_sharding_constraint(x) return x, out class Encoder(nn.Module): """Transformer Model Encoder for sequence to sequence translation.""" depth: int mlp_dim: int | None = None # Defaults to 4x input dim num_heads: int = 12 dropout: float = 0.0 scan: bool = False remat_policy: str = "nothing_saveable" dtype_mm: str = "float32" @nn.compact def __call__(self, x, deterministic=True): # noqa: FBT002 out = {} if self.scan: block = nn.remat( Encoder1DBlock, prevent_cse=False, static_argnums=(2,), # 0=self, 2=deterministic policy=getattr(jax.checkpoint_policies, self.remat_policy, None), ) x, scan_out = nn.scan( block, variable_axes={"params": 0}, split_rngs={"params": True, "dropout": True}, in_axes=nn.broadcast, length=self.depth, )( name="encoderblock", dtype_mm=self.dtype_mm, mlp_dim=self.mlp_dim, num_heads=self.num_heads, dropout=self.dropout, )(x, deterministic) for lyr in range(self.depth): out[f"block{lyr:02d}"] = jax.tree.map(lambda o, lyr=lyr: o[lyr], scan_out) else: # Input Encoder for lyr in range(self.depth): block_cur = Encoder1DBlock( name=f"encoderblock_{lyr}", dtype_mm=self.dtype_mm, mlp_dim=self.mlp_dim, num_heads=self.num_heads, dropout=self.dropout, ) x, out[f"block{lyr:02d}"] = block_cur(x, deterministic) out["pre_ln"] = x # Alias for last block, but without the number in it. 
return nn.LayerNorm(name="encoder_norm", dtype=self.dtype_mm)(x), out class MAPHead(nn.Module): """Multihead Attention Pooling.""" mlp_dim: int | None = None # Defaults to 4x input dim num_heads: int = 12 dtype_mm: str = "float32" @nn.compact def __call__(self, x): n, _, d = x.shape # n,l,d probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, d), x.dtype) probe = jnp.tile(probe, [n, 1, 1]) x = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, dtype=self.dtype_mm, kernel_init=nn.initializers.xavier_uniform(), )(probe, x) y = nn.LayerNorm(dtype=self.dtype_mm)(x) x = x + MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype_mm)(y) return x[:, 0] class _Module(nn.Module): """ViT model.""" num_classes: int | None = None patch_size: Sequence[int] = (16, 16) width: int = 768 depth: int = 12 mlp_dim: int | None = None # Defaults to 4x input dim num_heads: int = 12 posemb: str = "learn" # Can also be "sincos2d" rep_size: int | bool = False dropout: float = 0.0 pool_type: str = "gap" # Can also be "map" or "tok" head_zeroinit: bool = True scan: bool = False # or "dots_with_no_batch_dims_saveable" for more speed (memory costly) remat_policy: str = "nothing_saveable" dtype_mm: str = "float32" @nn.compact def __call__(self, image, *, train=False): out = {} # Kevin edit: do patch extraction and posemb in float32, # because I feel like it's a bit safer. image = jnp.asarray(image, jnp.float32) # Patch extraction x = out["stem"] = nn.Conv( self.width, self.patch_size, strides=self.patch_size, padding="VALID", name="embedding", dtype=jnp.float32, )(image) n, h, w, c = x.shape x = jnp.reshape(x, [n, h * w, c]) # Add posemb before adding extra token. 
x = out["with_posemb"] = x + get_posemb(self, self.posemb, (h, w), c, "pos_embedding", jnp.float32) if self.pool_type == "tok": cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype) x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1) n, _, c = x.shape # n,l,d x = nn.Dropout(rate=self.dropout)(x, not train) # Kevin edit: now cast back to dtype_mm (potentially half precision) x = x.astype(self.dtype_mm) x, out["encoder"] = Encoder( depth=self.depth, mlp_dim=self.mlp_dim, num_heads=self.num_heads, dropout=self.dropout, scan=self.scan, remat_policy=self.remat_policy, dtype_mm=self.dtype_mm, name="Transformer", )(x, deterministic=not train) encoded = out["encoded"] = x if self.pool_type == "map": x = out["head_input"] = MAPHead( num_heads=self.num_heads, mlp_dim=self.mlp_dim, dtype=self.dtype_mm, )(x) elif self.pool_type == "gap": x = out["head_input"] = jnp.mean(x, axis=1) elif self.pool_type == "0": x = out["head_input"] = x[:, 0] elif self.pool_type == "tok": x = out["head_input"] = x[:, 0] encoded = encoded[:, 1:] elif self.pool_type == "none": pass else: raise ValueError(f"Unknown pool type: '{self.pool_type}'") x_2d = jnp.reshape(encoded, [n, h, w, -1]) if self.rep_size: rep_size = self.width if self.rep_size is True else self.rep_size hid = nn.Dense(rep_size, dtype=self.dtype_mm, name="pre_logits") # NOTE: In the past we did not include tanh in pre_logits. # For few-shot, it should not matter much, as it whitens anyways. 
x_2d = nn.tanh(hid(x_2d)) x = nn.tanh(hid(x)) out["pre_logits_2d"] = x_2d out["pre_logits"] = x if self.num_classes: kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {} head = nn.Dense(self.num_classes, dtype=self.dtype_mm, name="head", **kw) x_2d = out["logits_2d"] = head(x_2d) x = out["logits"] = head(x) return x, out def Module(num_classes=None, *, variant=None, **kw): # pylint: disable=invalid-name # noqa: N802 """Factory function, because linen really don't like what I'm doing!""" return _Module(num_classes, **{**decode_variant(variant), **kw}) def decode_variant(variant): """Converts a string like "B" or "B/32" into a params dict.""" if variant is None: return {} v, patch = variant, {} if "/" in variant: v, patch = variant.split("/") patch = {"patch_size": (int(patch), int(patch))} return { # pylint:disable=line-too-long # Reference: Table 2 of https://arxiv.org/abs/2106.04560. "width": { "mu": 32, "Ti": 192, "S": 384, "M": 512, "B": 768, "L": 1024, "So400m": 1152, "H": 1280, "g": 1408, "g-opt": 1536, "G": 1664, "G-opt": 1536, "e": 1792, }[v], "depth": { "mu": 1, "Ti": 12, "S": 12, "M": 12, "B": 12, "L": 24, "So400m": 27, "H": 32, "g": 40, "g-opt": 40, "G": 48, "G-opt": 48, "e": 56, }[v], "mlp_dim": { "mu": 128, "Ti": 768, "S": 1536, "M": 2048, "B": 3072, "L": 4096, "So400m": 4304, "H": 5120, "g": 6144, "g-opt": 6144, "G": 8192, "G-opt": 8192, "e": 15360, }[v], "num_heads": { "mu": 2, "Ti": 3, "S": 6, "M": 8, "B": 12, "L": 16, "So400m": 16, "H": 16, "g": 16, "g-opt": 16, "G": 16, "G-opt": 16, "e": 16, }[v], # pylint:enable=line-too-long **patch, } ================================================ FILE: src/openpi/models/tokenizer.py ================================================ import logging import os import jax import numpy as np import orbax.checkpoint as ocp import sentencepiece from transformers import AutoProcessor import openpi.models.utils.fsq_tokenizer as fsq_tokenizer import openpi.shared.download as download class 
PaligemmaTokenizer: def __init__(self, max_len: int = 48): self._max_len = max_len path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) with path.open("rb") as f: self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) def tokenize(self, prompt: str, state: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]: cleaned_text = prompt.strip().replace("_", " ").replace("\n", " ") if state is not None: # This is the Pi05 format, where the state is part of the discrete language input. discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 state_str = " ".join(map(str, discretized_state)) full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " tokens = self._tokenizer.encode(full_prompt, add_bos=True) else: # This is the Pi0 format, where the state is part of the continuous action expert input. # tokenize "\n" separately as the "start of answer" token tokens = self._tokenizer.encode(cleaned_text, add_bos=True) + self._tokenizer.encode("\n") tokens_len = len(tokens) if tokens_len < self._max_len: padding = [False] * (self._max_len - tokens_len) mask = [True] * tokens_len + padding tokens = tokens + padding else: if len(tokens) > self._max_len: logging.warning( f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " "Consider increasing the `max_token_len` in your model config if this happens frequently." 
) tokens = tokens[: self._max_len] mask = [True] * self._max_len return np.asarray(tokens), np.asarray(mask) class FASTTokenizer: def __init__(self, max_len: int = 256, fast_tokenizer_path: str = "physical-intelligence/fast"): self._max_len = max_len # Download base PaliGemma tokenizer path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) with path.open("rb") as f: self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) # Instantiate FAST tokenizer self._fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True) self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens def tokenize( self, prompt: str, state: np.ndarray, actions: np.ndarray | None ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: cleaned_text = prompt.lower().strip().replace("_", " ") # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1]) discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 # Convention: prefix includes prompt and string-representation of state, followed by ';' state_str = " ".join(map(str, discretized_state)) prefix = f"Task: {cleaned_text}, State: {state_str};\n" prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True) if actions is not None: # Tokenize actions with FAST tokenizer --> map to last tokens in PaliGemma vocab action_tokens = self._fast_tokenizer(actions[None])[0] action_tokens_in_pg = self._act_tokens_to_paligemma_tokens(action_tokens) # Convention: postfix contains 'Action:' followed by FAST tokens, followed by '|' postfix_tokens = ( self._paligemma_tokenizer.encode("Action: ") + action_tokens_in_pg.tolist() + self._paligemma_tokenizer.encode("|", add_eos=True) ) else: postfix_tokens = [] # Create output token sequence & masks # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all 
previous tokens) tokens = prefix_tokens + postfix_tokens token_mask = [True] * len(tokens) ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens) loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only # Pad tokens to max length tokens_len = len(tokens) if tokens_len < self._max_len: padding = [False] * (self._max_len - tokens_len) tokens = tokens + padding token_mask = token_mask + padding ar_mask = ar_mask + padding loss_mask = loss_mask + padding else: if len(tokens) > self._max_len: logging.warning( f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " "Consider increasing the `max_token_len` in your model config if this happens frequently." ) tokens = tokens[: self._max_len] token_mask = token_mask[: self._max_len] ar_mask = ar_mask[: self._max_len] loss_mask = loss_mask[: self._max_len] return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask) def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray: # Decode predicted output tokens decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist()) # Extract actions from FAST model outputs if "Action: " not in decoded_tokens: return np.zeros((action_horizon, action_dim), dtype=np.float32) # Extract actions from decoded tokens raw_action_tokens = np.array( self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip()) ) action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens) return self._fast_tokenizer.decode( [action_tokens.tolist()], time_horizon=action_horizon, action_dim=action_dim )[0] def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray: if isinstance(tokens, list): tokens = np.array(tokens) return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens ########################################################################### ## The tokenizers below are 
used for RoboArena baseline implementations. ## ## They are *not* used for pi0-style models. ## ########################################################################### class BinningTokenizer: """ Standard RT-2 / OpenVLA style binning tokenizer. """ def __init__(self, max_len: int = 256, n_bins: int = 256): self._max_len = max_len self._n_bins = n_bins # Download base PaliGemma tokenizer path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) with path.open("rb") as f: self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens def tokenize( self, prompt: str, state: np.ndarray, actions: np.ndarray | None ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Tokenize a prompt and state into a sequence of tokens. Args: prompt: The text prompt to tokenize. state: The state array to discretize and tokenize. actions: Must be None. Action encoding is not currently supported. Returns: A tuple of (tokens, token_mask, ar_mask, targets). Raises: NotImplementedError: If actions is not None. 
""" cleaned_text = prompt.lower().strip().replace("_", " ") # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1]) discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 # Convention: prefix includes prompt and string-representation of state, followed by ';' state_str = " ".join(map(str, discretized_state)) prefix = f"Task: {cleaned_text}, State: {state_str};\n" prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True) if actions is not None: raise NotImplementedError("BinningTokenizer does not support encoding actions atm (only for inference use)") postfix_tokens = [] # Create output token sequence & masks # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens) tokens = prefix_tokens + postfix_tokens token_mask = [True] * len(tokens) ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens) loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only # Pad tokens to max length tokens_len = len(tokens) if tokens_len < self._max_len: padding = [False] * (self._max_len - tokens_len) tokens = tokens + padding token_mask = token_mask + padding ar_mask = ar_mask + padding loss_mask = loss_mask + padding else: if len(tokens) > self._max_len: logging.warning( f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " "Consider increasing the `max_token_len` in your model config if this happens frequently." 
) tokens = tokens[: self._max_len] token_mask = token_mask[: self._max_len] ar_mask = ar_mask[: self._max_len] loss_mask = loss_mask[: self._max_len] return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask) def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray: # Decode predicted output tokens decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist()) # Extract actions from FAST model outputs if "Action: " not in decoded_tokens: return np.zeros((action_horizon, action_dim), dtype=np.float32) # Extract actions from decoded tokens raw_action_tokens = np.array( self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip()) ) action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens) if len(action_tokens) < action_horizon * action_dim: return np.zeros([action_horizon, action_dim], dtype=np.float32) action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim]) return action_tokens / self._n_bins * 2 - 1 def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray: if isinstance(tokens, list): tokens = np.array(tokens) return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens class FSQTokenizer: """ FSQ tokenizer from the FAST paper baselines. 
""" def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None): self._max_len = max_len assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided" # Download tokenizer path = download.maybe_download(fsq_tokenizer_path) tok_path = os.path.join(path, os.listdir(path)[0]) # Split step from path step = int(tok_path.split("/")[-1]) base_path = tok_path.rsplit("/", 1)[0] mgr = ocp.CheckpointManager( base_path, item_handlers={ "params": ocp.StandardCheckpointHandler(), "opt_state": ocp.StandardCheckpointHandler(), "config": ocp.JsonCheckpointHandler(), }, options=ocp.CheckpointManagerOptions(max_to_keep=1), ) try: restored = mgr.restore( step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore()) ) config = restored["config"] self._params = restored["params"] self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config) except Exception as e: raise RuntimeError( f"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. 
Error: {e!s}" ) from e # Compile tokenize and detokenize functions self._tokenize_fn = jax.jit( lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.tokenize) ) self._detokenize_fn = jax.jit( lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.detokenize) ) # Download base PaliGemma tokenizer path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"}) with path.open("rb") as f: self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read()) self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens def tokenize( self, prompt: str, state: np.ndarray, actions: np.ndarray | None ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: cleaned_text = prompt.lower().strip().replace("_", " ") # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1]) discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 # Convention: prefix includes prompt and string-representation of state, followed by ';' state_str = " ".join(map(str, discretized_state)) prefix = f"Task: {cleaned_text}, State: {state_str};\n" prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True) if actions is not None: raise NotImplementedError("FSQTokenizer does not support encoding actions atm (only for inference use)") postfix_tokens = [] # Create output token sequence & masks # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens) tokens = prefix_tokens + postfix_tokens token_mask = [True] * len(tokens) ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens) loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only # Pad tokens to max length tokens_len = len(tokens) if tokens_len < self._max_len: padding = [False] * (self._max_len - tokens_len) 
tokens = tokens + padding token_mask = token_mask + padding ar_mask = ar_mask + padding loss_mask = loss_mask + padding else: if len(tokens) > self._max_len: logging.warning( f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. " "Consider increasing the `max_token_len` in your model config if this happens frequently." ) tokens = tokens[: self._max_len] token_mask = token_mask[: self._max_len] ar_mask = ar_mask[: self._max_len] loss_mask = loss_mask[: self._max_len] return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask) def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray: # Decode predicted output tokens decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist()) # Extract actions from FAST model outputs if "Action: " not in decoded_tokens: return np.zeros((action_horizon, action_dim), dtype=np.float32) # Extract actions from decoded tokens raw_action_tokens = np.array( self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip()) ) action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens) try: # Move computation to CPU and compile on-demand device = jax.devices("cpu")[0] with jax.default_device(device): detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0] return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim]) except Exception as e: logging.warning(f"Error decoding FSQ: {e}") return np.zeros((action_horizon, action_dim)) def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray: if isinstance(tokens, list): tokens = np.array(tokens) return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens ================================================ FILE: src/openpi/models/tokenizer_test.py ================================================ import numpy as np from openpi.models import tokenizer as _tokenizer 
class FsqCodebook(nn.Module):
    """Scalar-quantization codebook with learned down/up projections.

    Inputs are projected to ``len(bins_per_dim)`` channels, squashed with ``tanh`` and
    rounded to an integer digit per channel. The digits are combined into a single
    integer token using mixed-radix place values, and decoding inverts this mapping
    before projecting back to ``input_dim``.

    Attributes:
        input_dim: Feature size produced by ``decode`` (output size of ``proj_up``).
        target_codebook_size: Requested codebook size; selects the bins per dimension.
        codebook_type: Which bin layout to use: ``"fsq"``, ``"lfq"`` or ``"custom"``.
        _bins_per_dim: Optional explicit bin layout that overrides the lookup.
    """

    input_dim: int
    target_codebook_size: int
    # "custom" is handled by `bins_per_dim` below, so it belongs in the annotation as well.
    codebook_type: Literal["fsq", "lfq", "custom"]
    _bins_per_dim: tuple[int, ...] | None = None

    @property
    def bins_per_dim(self) -> tuple[int, ...]:
        """Number of quantization bins for each latent dimension.

        Raises:
            ValueError: If ``codebook_type`` is unknown or the codebook size is unsupported.
        """
        if self._bins_per_dim is not None:
            return self._bins_per_dim

        if self.codebook_type == "fsq":
            return self._get_bins_fsq(self.target_codebook_size)
        if self.codebook_type == "lfq":
            return self._get_bins_lfq(self.target_codebook_size)
        if self.codebook_type == "custom":
            return self._get_bins_custom(self.target_codebook_size)
        raise ValueError(f"Codebook type {self.codebook_type} not supported.")

    @property
    def place_values(self) -> jnp.ndarray:
        """Mixed-radix place value of each digit: ``[1, b0, b0*b1, ...]``."""
        place_values = [1]
        for b in self.bins_per_dim[:-1]:
            place_values.append(place_values[-1] * b)
        return jnp.array(place_values)

    @staticmethod
    def _get_bins_fsq(target_codebook_size: int) -> tuple[int, ...]:
        """
        Get bins per dimension based on codebook size, from the original FSQ paper.

        Raises:
            ValueError: If the codebook size has no predefined layout.
        """
        bins_by_size = {
            2**8: (8, 6, 5),
            2**10: (8, 5, 5, 5),
            2**12: (7, 5, 5, 5, 5),
            2**14: (8, 8, 8, 6, 5),
            2**16: (8, 8, 8, 5, 5, 5),
        }
        bins = bins_by_size.get(target_codebook_size)
        if bins is None:
            raise ValueError(f"Codebook size {target_codebook_size} not supported.")
        return bins

    @staticmethod
    def _get_bins_custom(target_codebook_size: int) -> tuple[int, ...]:
        """Get the two-dimensional bin layout used by the ``"custom"`` codebook type.

        Raises:
            ValueError: If the codebook size has no predefined layout. This mirrors
                ``_get_bins_fsq`` instead of returning ``None``, which would only fail
                later with an unrelated error when the bins are used.
        """
        bins_by_size = {
            2**8: (16, 16),
            2**10: (32, 32),
            2**12: (64, 64),
            2**14: (128, 128),
            2**16: (256, 256),
        }
        bins = bins_by_size.get(target_codebook_size)
        if bins is None:
            raise ValueError(f"Codebook size {target_codebook_size} not supported.")
        return bins

    @staticmethod
    def _get_bins_lfq(target_codebook_size: int) -> tuple[int, ...]:
        """
        Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension)
        """
        assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ"

        return (2,) * int(math.log2(target_codebook_size))

    def setup(self):
        """Create the projections into and out of the quantized latent space."""
        self.proj_down = nn.Dense(len(self.bins_per_dim))
        self.proj_up = nn.Dense(self.input_dim)

    def __call__(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Quantize ``inputs`` and reconstruct them; returns ``(tokens, reconstruction)``."""
        tokens, z = self.encode(inputs)
        # Passing z lets gradients flow through the rounding step (see `decode`).
        output = self.decode(tokens, z_grad=z)
        return tokens, output

    def encode(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(tokens, z)`` where ``z`` is the pre-quantization latent in [-1, 1]."""
        bases = jnp.array(self.bins_per_dim)

        x = self.proj_down(inputs)
        z = jnp.tanh(x)

        # Quantize: map [-1, 1] onto the integer digits 0..bases-1 of each dimension.
        digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32)
        tokens = self.undigitize(digits)

        return tokens, z

    def decode(self, tokens: jnp.ndarray, z_grad: jax.Array | None = None) -> jnp.ndarray:
        """Map integer tokens back to ``input_dim`` features.

        Args:
            tokens: Integer tokens as produced by ``encode``.
            z_grad: Optional continuous latent with the same shape as the quantized
                latent. When given, the output value is that of the quantized latent
                while gradients flow to ``z_grad`` (straight-through estimator).
        """
        bases = jnp.array(self.bins_per_dim)
        digits = self.digitize(tokens)

        # Inverse of the quantization in `encode`: digits 0..bases-1 back to [-1, 1].
        z_q = digits / (bases - 1) * 2 - 1

        if z_grad is not None:
            chex.assert_equal_shape([z_q, z_grad])
            z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad

        return self.proj_up(z_q)

    def undigitize(self, digits: jnp.ndarray) -> jnp.ndarray:
        """Combine per-dimension digits (last axis) into a single integer token."""
        return jnp.sum(digits * jnp.array(self.place_values), axis=-1)

    def digitize(self, tokens: jnp.ndarray) -> jnp.ndarray:
        """Split integer tokens into per-dimension digits, appended as a new last axis."""
        return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim)

    @property
    def vocab_size(self) -> int:
        """Number of distinct tokens, i.e. the product of all bins per dimension."""
        return math.prod(self.bins_per_dim)
class GeGLU(Module):
    """Gated linear unit with a GELU gate (GeGLU).

    A single dense layer produces twice the output width; the first half is the
    value and the second half is passed through GELU and used as a multiplicative gate.

    Attributes:
      output_dim: width of the returned features. The default of -1 means
        "use the size of the last axis of the inputs".
    """

    output_dim: int = -1

    @compact
    def __call__(self, inputs: Array) -> Array:
        """Applies the gated projection to the inputs.

        Args:
          inputs: the nd-array to transform; the projection acts on its last axis.

        Returns:
          The gated features, whose last axis has the resolved output width.
        """
        if self.output_dim == -1:
            width = inputs.shape[-1]
        else:
            width = self.output_dim

        # One projection yields both halves: [value | gate].
        projected = nn.Dense(width * 2)(inputs)
        value = projected[..., :width]
        gate = projected[..., width:]
        return value * nn.gelu(gate)
def sinusoidal_pe_init(_, shape: tuple[int, int]) -> jnp.ndarray:
    """Build a sinusoidal positional-encoding table, usable as a parameter initializer.

    The first argument is ignored (it occupies the slot an initializer receives before
    the shape). The result concatenates the sine features and then the cosine features
    along the last axis, so for an even ``d_embed`` it has shape ``(seq_len, d_embed)``.

    Args:
        _: Unused.
        shape: ``(seq_len, d_embed)`` of the requested table.

    Returns:
        Array whose first half of the last axis is ``sin`` and second half is ``cos``.
    """
    seq_len, d_embed = shape

    # Frequencies decay geometrically from 1 towards 1/10000 across the feature pairs.
    freq_exponent = -(jnp.log(10000.0) / d_embed)
    div_term = jnp.exp(jnp.arange(0, d_embed, 2) * freq_exponent)

    position = jnp.arange(0, seq_len, 1)
    angles = position[:, jnp.newaxis] * div_term

    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)
class FsqAttentionTokenizer(nn.Module):
    """Action tokenizer: attention encoder -> quantization codebook -> attention decoder.

    ``tokenize`` projects an action sequence, encodes it into ``num_tokens`` latents and
    quantizes them; ``detokenize`` decodes tokens back into ``data_horizon`` outputs of
    size ``data_dim``; ``loss`` runs the full round trip and reports reconstruction error.

    Attributes:
        embed_dim: Width of the internal embeddings.
        data_dim: Size of each reconstructed output vector.
        data_horizon: Number of query tokens of the decoder (and cross tokens of the encoder).
        num_tokens: Number of query tokens of the encoder (and cross tokens of the decoder).
        num_layers: Number of cross-attention layers in the encoder and in the decoder.
        target_codebook_size: Requested codebook size.
        causal: Forwarded to the encoder and decoder.
        mlp_ratio: Forwarded to the encoder and decoder.
        bound: If set, actions are clipped to ``[-bound, bound]`` in ``tokenize``.
        use_state_conditioning: Forwarded to the encoder and decoder.
    """

    embed_dim: int
    data_dim: int
    data_horizon: int
    num_tokens: int
    num_layers: int
    target_codebook_size: int
    causal: bool = False
    mlp_ratio: float = 2.0
    bound: float | None = None
    use_state_conditioning: bool = False

    @property
    def vocab_size(self) -> int:
        """Product of the FSQ bin layout for ``target_codebook_size``.

        NOTE(review): the codebook built in ``setup`` uses ``codebook_type="custom"``,
        whose bin layout differs from the FSQ one used here -- confirm which vocabulary
        size callers actually expect.
        """
        return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size))  # noqa: SLF001

    def setup(self):
        """Create the input projection, encoder, codebook, decoder and output head."""
        self.proj = nn.Dense(self.embed_dim)
        self.encoder = TokenizerEncoderDecoder(
            num_tokens=self.num_tokens,
            num_cross_tokens=self.data_horizon,
            num_layers=self.num_layers,
            causal=self.causal,
            use_state_conditioning=self.use_state_conditioning,
            mlp_ratio=self.mlp_ratio,
        )
        self.codebook = FsqCodebook(
            input_dim=self.embed_dim,
            target_codebook_size=self.target_codebook_size,
            codebook_type="custom",
        )
        self.decoder = TokenizerEncoderDecoder(
            num_tokens=self.data_horizon,
            num_cross_tokens=self.num_tokens,
            num_layers=self.num_layers,
            causal=self.causal,
            use_state_conditioning=self.use_state_conditioning,
            mlp_ratio=self.mlp_ratio,
        )

        self.proj_mean = nn.Dense(self.data_dim)
        # Learned scalar applied to the decoded mean, initialised to 1.
        self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0))

    def tokenize(
        self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = False
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Encode actions; returns the codebook's ``(tokens, z)`` pair.

        Args:
            action: Action sequence to encode.
            obs: Optional state conditioning forwarded to the encoder.
            train: Whether the encoder runs in training mode.
        """
        if self.bound is not None:
            action = jnp.clip(action, -self.bound, self.bound)

        x = self.proj(action)
        x = self.encoder(x, train=train, state_conditioning=obs)

        return self.codebook.encode(x)

    def detokenize(self, tokens: jnp.ndarray, *, obs: jnp.ndarray | None = None) -> jnp.ndarray:
        """Decode tokens into the reconstructed (scaled) action mean.

        Args:
            tokens: Integer tokens as produced by ``tokenize``.
            obs: Optional state conditioning forwarded to the decoder.
        """
        # The decoder's `train` argument defaults to True; decoding is an inference-only
        # path, so request deterministic behaviour explicitly.
        x = self.decoder(self.codebook.decode(tokens), train=False, state_conditioning=obs)
        mean = self.proj_mean(x)
        return mean * self.out_scale

    def loss(
        self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = True
    ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
        """Run encode -> quantize -> decode and return ``(mse, {"mse": ..., "mae": ...})``.

        NOTE(review): unlike ``tokenize``, the action is not clipped to ``bound`` here --
        confirm this asymmetry is intended.
        """
        # Encode
        x = self.proj(action)
        z = self.encoder(x, train=train, state_conditioning=obs)

        # Quantize (the token ids are not needed for the reconstruction loss)
        _tokens, z = self.codebook(z)

        # Decode
        x = self.decoder(z, train=train, state_conditioning=obs)
        mean = self.proj_mean(x) * self.out_scale

        mse = jnp.mean(jnp.square(action - mean))
        mae = jnp.mean(jnp.abs(action - mean))

        return mse, {
            "mse": mse,
            "mae": mae,
        }

    def __call__(self, *args: Any, **kwargs: Any) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
        """
        Dummy for .init
        """
        return self.loss(*args, **kwargs)
"""ViT implementation adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py.""" from collections.abc import Callable from typing import Any import flax.linen as nn import jax import jax.numpy as jnp from openpi.models import resnet as models_resnet Array = Any PRNGKey = Any Shape = tuple[int] Dtype = Any class IdentityLayer(nn.Module): """Identity layer, convenient for giving a name to an array.""" @nn.compact def __call__(self, x): return x class AddPositionEmbs(nn.Module): """Adds learned positional embeddings to the inputs. Attributes: posemb_init: positional embedding initializer. """ posemb_init: Callable[[PRNGKey, Shape, Dtype], Array] param_dtype: Dtype = jnp.float32 @nn.compact def __call__(self, inputs): """Applies the AddPositionEmbs module. Args: inputs: Inputs to the layer. Returns: Output tensor with shape `(bs, timesteps, in_dim)`. """ # inputs.shape is (batch_size, seq_len, emb_dim). assert inputs.ndim == 3, f"Number of dimensions should be 3, but it is: {inputs.ndim}" pos_emb_shape = (1, inputs.shape[1], inputs.shape[2]) pe = self.param("pos_embedding", self.posemb_init, pos_emb_shape, self.param_dtype) return inputs + pe class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" mlp_dim: int dtype: Dtype = jnp.float32 param_dtype: Dtype = jnp.float32 out_dim: int | None = None dropout_rate: float = 0.1 kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform() bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6) @nn.compact def __call__(self, inputs, *, deterministic): """Applies Transformer MlpBlock module.""" actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense( features=self.mlp_dim, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=self.kernel_init, bias_init=self.bias_init, )( # pytype: disable=wrong-arg-types inputs ) x = nn.gelu(x) x = nn.Dropout(rate=self.dropout_rate)(x, 
class Encoder1DBlock(nn.Module):
    """Transformer encoder layer: LayerNorm + self-attention, then LayerNorm + MLP.

    Both sub-blocks are wrapped in residual connections.

    Attributes:
      mlp_dim: dimension of the mlp on top of attention block.
      num_heads: Number of heads in nn.MultiHeadDotProductAttention
      dtype: the dtype of the computation (default: float32).
      dropout_rate: dropout rate applied after attention and inside the MLP block.
      attention_dropout_rate: dropout for attention heads.
    """

    mlp_dim: int
    num_heads: int
    dtype: Dtype = jnp.float32
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, inputs, deterministic):
        """Applies Encoder1DBlock module.

        Args:
          inputs: Inputs to the layer; must have exactly three dimensions.
          deterministic: Dropout will not be applied when set to true.

        Returns:
          A pair ``(output, None)``. The second element is always None; the
          Encoder in this file unpacks the pair when running the block under nn.scan.
        """
        assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"

        # Attention block.
        normed_inputs = nn.LayerNorm(dtype=self.dtype)(inputs)
        self_attention = nn.MultiHeadDotProductAttention(
            dtype=self.dtype,
            kernel_init=nn.initializers.xavier_uniform(),
            broadcast_dropout=False,
            deterministic=deterministic,
            dropout_rate=self.attention_dropout_rate,
            num_heads=self.num_heads,
            # Not the library default, so it is requested explicitly.
            force_fp32_for_softmax=True,
        )
        attended = self_attention(normed_inputs, normed_inputs)
        attended = nn.Dropout(rate=self.dropout_rate)(attended, deterministic=deterministic)
        after_attention = attended + inputs

        # MLP block.
        normed_hidden = nn.LayerNorm(dtype=self.dtype)(after_attention)
        feed_forward = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)
        mlp_out = feed_forward(normed_hidden, deterministic=deterministic)

        return after_attention + mlp_out, None
Attributes: num_layers: number of layers mlp_dim: dimension of the mlp on top of attention block num_heads: Number of heads in nn.MultiHeadDotProductAttention dropout_rate: dropout rate. attention_dropout_rate: dropout rate in self attention. """ dtype: jax.typing.DTypeLike num_layers: int mlp_dim: int num_heads: int dropout_rate: float = 0.1 attention_dropout_rate: float = 0.1 add_position_embedding: bool = True @nn.compact def __call__(self, x, *, train): """Applies Transformer model on the inputs. Args: x: Inputs to the layer. train: Set to `True` when training. Returns: output of a transformer encoder. """ assert x.ndim == 3 # (batch, len, emb) if self.add_position_embedding: x = AddPositionEmbs( posemb_init=nn.initializers.normal(stddev=0.02), # from BERT. name="posembed_input", )(x) x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) x = x.astype(self.dtype) # Input Encoder block = nn.remat(Encoder1DBlock, prevent_cse=False, static_argnums=(2,)) x, _ = nn.scan( block, variable_axes={"params": 0}, split_rngs={"params": True, "dropout": True}, in_axes=nn.broadcast, length=self.num_layers, )( name="encoderblock", mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate, attention_dropout_rate=self.attention_dropout_rate, dtype=self.dtype, num_heads=self.num_heads, )(x, not train) return nn.LayerNorm(name="encoder_norm", dtype=self.dtype)(x) class VisionTransformer(nn.Module): """VisionTransformer.""" dtype: jax.typing.DTypeLike num_classes: int patches: Any transformer: Any hidden_size: int resnet: Any | None = None representation_size: int | None = None classifier: str = "token" head_bias_init: float = 0.0 encoder: type[nn.Module] = Encoder model_name: str | None = None @nn.compact def __call__(self, inputs, *, train): x = inputs # (Possibly partial) ResNet root. if self.resnet is not None: width = int(64 * self.resnet.width_factor) # Root block. 
x = models_resnet.StdConv( features=width, kernel_size=(7, 7), strides=(2, 2), use_bias=False, name="conv_root" )(x) x = nn.GroupNorm(name="gn_root")(x) x = nn.relu(x) x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME") # ResNet stages. if self.resnet.num_layers: x = models_resnet.ResNetStage( block_size=self.resnet.num_layers[0], nout=width, first_stride=(1, 1), name="block1" )(x) for i, block_size in enumerate(self.resnet.num_layers[1:], 1): x = models_resnet.ResNetStage( block_size=block_size, nout=width * 2**i, first_stride=(2, 2), name=f"block{i + 1}" )(x) n, h, w, c = x.shape # We can merge s2d+emb into a single conv; it's the same. x = nn.Conv( features=self.hidden_size, kernel_size=self.patches.size, strides=self.patches.size, padding="VALID", name="embedding", )(x) # Here, x is a grid of embeddings. # (Possibly partial) Transformer. if self.transformer is not None: n, h, w, c = x.shape x = jnp.reshape(x, [n, h * w, c]) # If we want to add a class token, add it here. 
if self.classifier in ["token", "token_unpooled"]: cls = self.param("cls", nn.initializers.zeros, (1, 1, c)) cls = jnp.tile(cls, [n, 1, 1]) x = jnp.concatenate([cls, x], axis=1) x = self.encoder(name="Transformer", **self.transformer, dtype=self.dtype)(x, train=train) if self.classifier == "token": x = x[:, 0] elif self.classifier == "gap": x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) # (1,) or (1,2) elif self.classifier in ["unpooled", "token_unpooled"]: pass else: raise ValueError(f"Invalid classifier={self.classifier}") if self.representation_size is not None: x = nn.Dense(features=self.representation_size, name="pre_logits")(x) x = nn.tanh(x) else: x = IdentityLayer(name="pre_logits")(x) if self.num_classes: x = nn.Dense( features=self.num_classes, name="head", kernel_init=nn.initializers.zeros, bias_init=nn.initializers.constant(self.head_bias_init), )(x) return x ================================================ FILE: src/openpi/models_pytorch/gemma_pytorch.py ================================================ from typing import Literal import pytest import torch from torch import nn from transformers import GemmaForCausalLM from transformers import PaliGemmaForConditionalGeneration from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma class PaliGemmaWithExpertModel(nn.Module): def __init__( self, vlm_config, action_expert_config, use_adarms=None, precision: Literal["bfloat16", "float32"] = "bfloat16", ): if use_adarms is None: use_adarms = [False, False] super().__init__() vlm_config_hf = CONFIG_MAPPING["paligemma"]() vlm_config_hf._vocab_size = 257152 # noqa: SLF001 vlm_config_hf.image_token_index = 257152 vlm_config_hf.text_config.hidden_size = vlm_config.width vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads vlm_config_hf.text_config.head_dim = vlm_config.head_dim vlm_config_hf.text_config.num_hidden_layers = 
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
    """Cast the module to ``precision``, keeping selected parameters in float32.

    With ``"float32"`` the whole module is cast to float32 and nothing else happens.
    With ``"bfloat16"`` the whole module is cast to bfloat16 and then every parameter
    whose name contains one of the selectors below is cast back to float32.

    Raises:
        ValueError: If ``precision`` is neither ``"bfloat16"`` nor ``"float32"``.
    """
    if precision == "float32":
        self.to(dtype=torch.float32)
        return
    if precision != "bfloat16":
        raise ValueError(f"Invalid precision: {precision}")

    self.to(dtype=torch.bfloat16)

    # Substrings of parameter names that must stay in full precision.
    float32_selectors = (
        "vision_tower.vision_model.embeddings.patch_embedding.weight",
        "vision_tower.vision_model.embeddings.patch_embedding.bias",
        "vision_tower.vision_model.embeddings.position_embedding.weight",
        "input_layernorm",
        "post_attention_layernorm",
        "model.norm",
    )

    for param_name, param in self.named_parameters():
        for selector in float32_selectors:
            if selector in param_name:
                param.data = param.data.to(dtype=torch.float32)
                break
supports it if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"): if not self.gemma_expert.model.gradient_checkpointing: print("Forcing gradient checkpointing to be enabled for Gemma expert model") self.gemma_expert.model.gradient_checkpointing = True use_gradient_checkpointing = True # Debug gradient checkpointing status if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed: print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}") print(f"Model training mode: {self.training}") print( f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}" ) if hasattr(self.gemma_expert.model, "gradient_checkpointing"): print( f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}" ) self._debug_gc_printed = True # Define the complete layer computation function for gradient checkpointing def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond): models = [self.paligemma.language_model, self.gemma_expert.model] query_states = [] key_states = [] value_states = [] gates = [] for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 gates.append(gate) input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states.append(query_state) key_states.append(key_state) value_states.append(value_state) # Concatenate and process attention query_states = torch.cat(query_states, dim=2) key_states = torch.cat(key_states, dim=2) value_states = torch.cat(value_states, 
dim=2) dummy_tensor = torch.zeros( query_states.shape[0], query_states.shape[2], query_states.shape[-1], device=query_states.device, dtype=query_states.dtype, ) cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) query_states, key_states = modeling_gemma.apply_rotary_pos_emb( query_states, key_states, cos, sin, unsqueeze_dim=1 ) batch_size = query_states.shape[0] scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling # Attention computation att_output, _ = modeling_gemma.eager_attention_forward( self.paligemma.language_model.layers[layer_idx].self_attn, query_states, key_states, value_states, attention_mask, scaling, ) # Get head_dim from the current layer, not from the model head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) # Process layer outputs outputs_embeds = [] start_pos = 0 for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] end_pos = start_pos + hidden_states.shape[1] if att_output.dtype != layer.self_attn.o_proj.weight.dtype: att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) # first residual out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 after_first_residual = out_emb.clone() out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) # Convert to bfloat16 if the next layer (mlp) uses bfloat16 if layer.mlp.up_proj.weight.dtype == torch.bfloat16: out_emb = out_emb.to(dtype=torch.bfloat16) out_emb = layer.mlp(out_emb) # second residual out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 outputs_embeds.append(out_emb) start_pos = end_pos return outputs_embeds # Process all layers with gradient checkpointing if enabled for layer_idx in range(num_layers): if use_gradient_checkpointing: inputs_embeds = 
def get_safe_dtype(target_dtype, device_type):
    """Get a safe dtype for the given device type.

    bfloat16 is replaced by float32 when ``device_type`` is ``"cpu"``; every other
    combination returns ``target_dtype`` unchanged.
    """
    if device_type == "cpu" and target_dtype == torch.bfloat16:
        return torch.float32
    return target_dtype


def create_sinusoidal_pos_embedding(
    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
    """Computes sine-cosine positional embedding vectors for scalar positions.

    Periods are spaced geometrically between ``min_period`` and ``max_period``; the
    result concatenates the sine features followed by the cosine features.

    Args:
        time: 1-D tensor of scalar positions, shape ``(batch_size,)``.
        dimension: Size of the embedding; must be even.
        min_period: Smallest period.
        max_period: Largest period.
        device: Device on which the frequencies are created. Accepts either a
            ``torch.device`` or a device string such as ``"cpu"``.

    Returns:
        Tensor of shape ``(batch_size, dimension)``.

    Raises:
        ValueError: If ``dimension`` is odd or ``time`` is not one-dimensional.
    """
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    # The default is a plain string, which has no `.type`; normalise so that both
    # strings and torch.device objects are supported.
    device = torch.device(device)
    dtype = get_safe_dtype(torch.float64, device.type)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Compute the outer product
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
""" if att_masks.ndim != 2: raise ValueError(att_masks.ndim) if pad_masks.ndim != 2: raise ValueError(pad_masks.ndim) cumsum = torch.cumsum(att_masks, dim=1) att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] return att_2d_masks & pad_2d_masks class PI0Pytorch(nn.Module): def __init__(self, config): super().__init__() self.config = config self.pi05 = config.pi05 paligemma_config = _gemma.get_config(config.paligemma_variant) action_expert_config = _gemma.get_config(config.action_expert_variant) self.paligemma_with_expert = PaliGemmaWithExpertModel( paligemma_config, action_expert_config, use_adarms=[False, True] if self.pi05 else [False, False], precision=config.dtype, ) self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width) self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim) if self.pi05: self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) else: self.state_proj = nn.Linear(config.action_dim, action_expert_config.width) self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width) self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) torch.set_float32_matmul_precision("high") if config.pytorch_compile_mode is not None: self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode) # Initialize gradient checkpointing flag self.gradient_checkpointing_enabled = False msg = "transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`." 
try: from transformers.models.siglip import check if not check.check_whether_transformers_replace_is_installed_correctly(): raise ValueError(msg) except ImportError: raise ValueError(msg) from None def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" self.gradient_checkpointing_enabled = True self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True logging.info("Enabled gradient checkpointing for PI0Pytorch model") def gradient_checkpointing_disable(self): """Disable gradient checkpointing.""" self.gradient_checkpointing_enabled = False self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False logging.info("Disabled gradient checkpointing for PI0Pytorch model") def is_gradient_checkpointing_enabled(self): """Check if gradient checkpointing is enabled.""" return self.gradient_checkpointing_enabled def _apply_checkpoint(self, func, *args, **kwargs): """Helper method to apply gradient checkpointing if enabled.""" if self.gradient_checkpointing_enabled and self.training: return torch.utils.checkpoint.checkpoint( func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs ) return func(*args, **kwargs) def _prepare_attention_masks_4d(self, att_2d_masks): """Helper method to prepare 4D attention masks for transformer.""" att_2d_masks_4d = att_2d_masks[:, None, :, :] return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38) def _preprocess_observation(self, observation, *, train=True): """Helper method to preprocess observation.""" observation = _preprocessing.preprocess_observation_pytorch(observation, train=train) return ( list(observation.images.values()), 
list(observation.image_masks.values()), observation.tokenized_prompt, observation.tokenized_prompt_mask, observation.state, ) def sample_noise(self, shape, device): return torch.normal( mean=0.0, std=1.0, size=shape, dtype=torch.float32, device=device, ) def sample_time(self, bsize, device): time_beta = sample_beta(1.5, 1.0, bsize, device) time = time_beta * 0.999 + 0.001 return time.to(dtype=torch.float32, device=device) def embed_prefix( self, images, img_masks, lang_tokens, lang_masks ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Embed images with SigLIP and language tokens with embedding layer to prepare for PaliGemma transformer processing. """ embs = [] pad_masks = [] att_masks = [] # Process images for img, img_mask in zip(images, img_masks, strict=True): def image_embed_func(img): return self.paligemma_with_expert.embed_image(img) img_emb = self._apply_checkpoint(image_embed_func, img) bsize, num_img_embs = img_emb.shape[:2] embs.append(img_emb) pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) # Create attention masks so that image tokens attend to each other att_masks += [0] * num_img_embs # Process language tokens def lang_embed_func(lang_tokens): lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) lang_emb_dim = lang_emb.shape[-1] return lang_emb * math.sqrt(lang_emb_dim) lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) embs.append(lang_emb) pad_masks.append(lang_masks) # full attention between image and language inputs num_lang_embs = lang_emb.shape[1] att_masks += [0] * num_lang_embs embs = torch.cat(embs, dim=1) pad_masks = torch.cat(pad_masks, dim=1) att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) # Get batch size from the first dimension of the concatenated tensors bsize = pad_masks.shape[0] att_masks = att_masks[None, :].expand(bsize, len(att_masks)) return embs, pad_masks, att_masks def embed_suffix(self, state, noisy_actions, timestep): """Embed state, 
noisy_actions, timestep to prepare for Expert Gemma processing.""" embs = [] pad_masks = [] att_masks = [] if not self.pi05: if self.state_proj.weight.dtype == torch.float32: state = state.to(torch.float32) # Embed state def state_proj_func(state): return self.state_proj(state) state_emb = self._apply_checkpoint(state_proj_func, state) embs.append(state_emb[:, None, :]) bsize = state_emb.shape[0] device = state_emb.device state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device) pad_masks.append(state_mask) # Set attention masks so that image and language inputs do not attend to state or actions att_masks += [1] # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] time_emb = create_sinusoidal_pos_embedding( timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0, device=timestep.device ) time_emb = time_emb.type(dtype=timestep.dtype) # Fuse timestep + action information using an MLP def action_proj_func(noisy_actions): return self.action_in_proj(noisy_actions) action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) if not self.pi05: time_emb = time_emb[:, None, :].expand_as(action_emb) action_time_emb = torch.cat([action_emb, time_emb], dim=2) # Apply MLP layers def mlp_func(action_time_emb): x = self.action_time_mlp_in(action_time_emb) x = F.silu(x) # swish == silu return self.action_time_mlp_out(x) action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb) adarms_cond = None else: # time MLP (for adaRMS) def time_mlp_func(time_emb): x = self.time_mlp_in(time_emb) x = F.silu(x) # swish == silu x = self.time_mlp_out(x) return F.silu(x) time_emb = self._apply_checkpoint(time_mlp_func, time_emb) action_time_emb = action_emb adarms_cond = time_emb # Add to input tokens embs.append(action_time_emb) bsize, action_time_dim = action_time_emb.shape[:2] action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device) pad_masks.append(action_time_mask) # 
def forward(self, observation, actions, noise=None, time=None) -> Tensor:
    """Run a full training forward pass and return the unreduced loss.

    Args:
        observation: Raw observation; it is passed through
            ``self._preprocess_observation(..., train=True)``.
        actions: Target actions. Assumed to be (batch, action_horizon, action_dim)
            -- TODO confirm against the data loader.
        noise: Optional noise with the same shape as ``actions``; sampled with
            ``self.sample_noise`` when omitted.
        time: Optional per-example time values, indexed as a 1-D tensor; sampled
            with ``self.sample_time`` when omitted.

    Returns:
        Element-wise squared error (``reduction="none"``) between the target
        ``noise - actions`` and the model prediction.
    """
    images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=True)

    # Noise is drawn before time; keep this order so seeded runs stay reproducible.
    if noise is None:
        noise = self.sample_noise(actions.shape, actions.device)

    if time is None:
        time = self.sample_time(actions.shape[0], actions.device)

    # Blend actions and noise: time == 1 gives pure noise, time == 0 gives the clean actions.
    time_expanded = time[:, None, None]
    x_t = time_expanded * noise + (1 - time_expanded) * actions
    # Regression target for the network output.
    u_t = noise - actions

    prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
    suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
    # Match the embeddings to the backbone precision; the first language-model
    # layer's q_proj weight is used as the dtype probe.
    if (
        self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
        == torch.bfloat16
    ):
        suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
        prefix_embs = prefix_embs.to(dtype=torch.bfloat16)

    # Prefix and suffix are joined along the sequence axis.
    pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
    att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)

    att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
    # Positions count only non-padding tokens (running sum of the pad mask, zero-based).
    position_ids = torch.cumsum(pad_masks, dim=1) - 1

    # Prepare attention masks
    att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)

    # Apply gradient checkpointing if enabled
    def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
        # Only the suffix (second) output is needed for the loss.
        (_, suffix_out), _ = self.paligemma_with_expert.forward(
            attention_mask=att_2d_masks_4d,
            position_ids=position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, suffix_embs],
            use_cache=False,
            adarms_cond=[None, adarms_cond],
        )
        return suffix_out

    suffix_out = self._apply_checkpoint(
        forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
    )

    # Keep the last `action_horizon` positions, i.e. the action tokens at the end of the suffix.
    suffix_out = suffix_out[:, -self.config.action_horizon :]
    # The output projection is applied in float32.
    suffix_out = suffix_out.to(dtype=torch.float32)

    # Apply gradient checkpointing to final action projection if enabled
    def action_out_proj_func(suffix_out):
        return self.action_out_proj(suffix_out)

    v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)

    return F.mse_loss(u_t, v_t, reduction="none")
def preprocess_observation_pytorch(
    observation,
    *,
    train: bool = False,
    image_keys: Sequence[str] = IMAGE_KEYS,
    image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
):
    """Resize, optionally augment, and mask the images of an observation.

    For every key in ``image_keys`` the image is brought to channels-last layout,
    resized with padding to ``image_resolution`` when its spatial size differs,
    and, when ``train`` is True, augmented (crop/resize and small rotation for
    non-wrist cameras, then brightness/contrast/saturation jitter for all
    cameras). Images that arrived channels-first are returned channels-first.

    The function is written with tensor-only random draws and plain type
    annotations to stay compatible with ``torch.compile``.

    Args:
        observation: Object exposing ``images``, ``image_masks``, ``state``,
            ``tokenized_prompt``, ``tokenized_prompt_mask``, ``token_ar_mask`` and
            ``token_loss_mask``. Image values are assumed to lie in [-1, 1]
            (implied by the rescaling below) -- TODO confirm with the data pipeline.
        train: Enables the random augmentations.
        image_keys: Image names that must be present in ``observation.images``.
        image_resolution: Target spatial size, passed to ``resize_with_pad_torch``.

    Returns:
        A lightweight object carrying the processed ``images`` and ``image_masks``
        plus the untouched state and token fields of ``observation``.

    Raises:
        ValueError: If any of ``image_keys`` is missing from ``observation.images``.
    """
    if not set(image_keys).issubset(observation.images):
        raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")

    # Every dimension of the state except the last one is treated as the batch shape.
    batch_shape = observation.state.shape[:-1]

    out_images = {}
    for key in image_keys:
        image = observation.images[key]

        # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats
        # NOTE(review): a channels-last image whose height is exactly 3 would be
        # misdetected as channels-first by this check.
        is_channels_first = image.shape[1] == 3  # Check if channels are in dimension 1

        if is_channels_first:
            # Convert [B, C, H, W] to [B, H, W, C] for processing
            image = image.permute(0, 2, 3, 1)

        if image.shape[1:3] != image_resolution:
            logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
            image = image_tools.resize_with_pad_torch(image, *image_resolution)

        if train:
            # Convert from [-1, 1] to [0, 1] for PyTorch augmentations
            image = image / 2.0 + 0.5

            # Apply PyTorch-based augmentations
            if "wrist" not in key:
                # Geometric augmentations for non-wrist cameras
                height, width = image.shape[1:3]

                # Random crop and resize
                crop_height = int(height * 0.95)
                crop_width = int(width * 0.95)

                # Random crop
                max_h = height - crop_height
                max_w = width - crop_width
                if max_h > 0 and max_w > 0:
                    # Use tensor operations instead of .item() for torch.compile compatibility
                    # A single offset pair is drawn, so the whole batch shares one crop window.
                    start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
                    start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
                    image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]

                # Resize back to original size
                image = torch.nn.functional.interpolate(
                    image.permute(0, 3, 1, 2),  # [b, h, w, c] -> [b, c, h, w]
                    size=(height, width),
                    mode="bilinear",
                    align_corners=False,
                ).permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]

                # Random rotation (small angles)
                # Use tensor operations instead of .item() for torch.compile compatibility
                # One angle is drawn and applied to every image in the batch.
                angle = torch.rand(1, device=image.device) * 10 - 5  # Random angle between -5 and 5 degrees
                if torch.abs(angle) > 0.1:  # Only rotate if angle is significant
                    # Convert to radians
                    angle_rad = angle * torch.pi / 180.0

                    # Create rotation matrix
                    cos_a = torch.cos(angle_rad)
                    sin_a = torch.sin(angle_rad)

                    # Apply rotation using grid_sample
                    grid_x = torch.linspace(-1, 1, width, device=image.device)
                    grid_y = torch.linspace(-1, 1, height, device=image.device)

                    # Create meshgrid
                    grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")

                    # Expand to batch dimension
                    grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
                    grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)

                    # Apply rotation transformation
                    grid_x_rot = grid_x * cos_a - grid_y * sin_a
                    grid_y_rot = grid_x * sin_a + grid_y * cos_a

                    # Stack and reshape for grid_sample
                    grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)

                    # Samples that fall outside the image are filled with zeros.
                    image = torch.nn.functional.grid_sample(
                        image.permute(0, 3, 1, 2),  # [b, h, w, c] -> [b, c, h, w]
                        grid,
                        mode="bilinear",
                        padding_mode="zeros",
                        align_corners=False,
                    ).permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]

            # Color augmentations for all cameras
            # Random brightness
            # Use tensor operations instead of .item() for torch.compile compatibility
            brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6  # Random factor between 0.7 and 1.3
            image = image * brightness_factor

            # Random contrast
            # Use tensor operations instead of .item() for torch.compile compatibility
            contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8  # Random factor between 0.6 and 1.4
            # Contrast is scaled around the per-image mean.
            mean = image.mean(dim=[1, 2, 3], keepdim=True)
            image = (image - mean) * contrast_factor + mean

            # Random saturation, approximated without an HSV round trip: each pixel
            # is scaled around the mean of its own colour channels.
            # Use tensor operations instead of .item() for torch.compile compatibility
            saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0  # Random factor between 0.5 and 1.5
            gray = image.mean(dim=-1, keepdim=True)
            image = gray + (image - gray) * saturation_factor

            # Clamp values to [0, 1]
            image = torch.clamp(image, 0, 1)

            # Back to [-1, 1]
            image = image * 2.0 - 1.0

        # Convert back to [B, C, H, W] format if it was originally channels-first
        if is_channels_first:
            image = image.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]

        out_images[key] = image

    # obtain mask
    out_masks = {}
    for key in out_images:
        if key not in observation.image_masks:
            # do not mask by default
            out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)
        else:
            out_masks[key] = observation.image_masks[key]

    # Create a simple object with the required attributes instead of using the complex Observation class
    class SimpleProcessedObservation:
        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)

    return SimpleProcessedObservation(
        images=out_images,
        image_masks=out_masks,
        state=observation.state,
        tokenized_prompt=observation.tokenized_prompt,
        tokenized_prompt_mask=observation.tokenized_prompt_mask,
        token_ar_mask=observation.token_ar_mask,
        token_loss_mask=observation.token_loss_mask,
    )
class GemmaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate a Gemma
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma-7B.
    e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`GemmaModel`]
        hidden_size (`int`, *optional*, defaults to 3072):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 24576):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 28):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The legacy activation function. It is overwritten by the `hidden_activation`.
        hidden_activation (`str` or `function`, *optional*):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        use_adarms (`bool`, *optional*, defaults to `False`):
            Whether to use ADARMS.
        adarms_cond_dim (`int`, *optional*, defaults to `None`):
            The dimension of the ADARMS condition. When `use_adarms` is `True` and this is left as `None`, it is set
            to `hidden_size`.
    ```python
    >>> from transformers import GemmaModel, GemmaConfig
    >>> # Initializing a Gemma gemma-7b style configuration
    >>> configuration = GemmaConfig()
    >>> # Initializing a model from the gemma-7b style configuration
    >>> model = GemmaModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Per-submodule sharding styles ("colwise" / "rowwise"), keyed by module-name pattern.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Input and output names of each top-level stage.
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=256000,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        hidden_activation=None,
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        use_adarms: bool = False,
        adarms_cond_dim: Optional[int] = None,
        **kwargs,
    ):
        # Store the architecture hyper-parameters verbatim on the config object.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.hidden_activation = hidden_activation
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_adarms = use_adarms
        self.adarms_cond_dim = adarms_cond_dim

        # Set default for adarms_cond_dim if use_adarms is True
        if self.use_adarms and self.adarms_cond_dim is None:
            self.adarms_cond_dim = self.hidden_size

        # Token ids, embedding tying and any remaining kwargs are handled by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
class GemmaRMSNorm(nn.Module):
    """RMS normalisation with an optional adaptive (conditioned) mode.

    Without ``cond_dim`` the layer owns a learned ``weight`` and computes
    ``norm(x) * (1 + weight)``. With ``cond_dim`` it owns a ``dense`` projection
    instead, which maps a conditioning vector to ``(scale, shift, gate)``; the
    output is ``norm(x) * (1 + scale) + shift`` and ``gate`` is returned to the
    caller.

    An adaptive layer has no ``weight`` parameter, so it must be called with a
    ``cond`` tensor.

    Args:
        dim: Size of the normalised (last) dimension.
        eps: Constant added to the mean square before the reciprocal square root.
        cond_dim: Size of the conditioning vector; ``None`` selects the plain mode.
    """

    def __init__(self, dim: int, eps: float = 1e-6, cond_dim: Optional[int] = None):
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.cond_dim = cond_dim

        # Dense layer for adaptive normalization (if cond_dim is provided)
        if cond_dim is not None:
            self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
            # Initialize with zeros (matches source implementation)
            nn.init.zeros_(self.dense.weight)
        else:
            self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16))
            self.dense = None

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        # The mean square is computed in float32 regardless of the input dtype.
        var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
        # Multiplying by the float32 factor promotes the product to float32.
        normed_inputs = x * torch.rsqrt(var + self.eps)
        return normed_inputs

    def forward(self, x, cond=None):
        """Normalise ``x``, optionally modulated by ``cond``.

        Args:
            x: Input tensor; the last dimension is normalised.
            cond: Conditioning tensor whose last dimension equals ``cond_dim``,
                or ``None`` for plain RMSNorm.

        Returns:
            ``(output, gate)``, both cast to the dtype of ``x``. ``gate`` is
            ``None`` in plain mode.

        Raises:
            ValueError: If the last dimension of ``cond`` differs from ``cond_dim``.
        """
        dtype = x.dtype  # original dtype, could be half-precision
        normed_inputs = self._norm(x)

        if cond is None or self.dense is None:
            # regular RMSNorm
            # scale by learned parameter in float32 (matches source implementation)
            normed_inputs = normed_inputs * (1.0 + self.weight.float())
            return normed_inputs.to(dtype), None  # return in original dtype with None gate

        # adaptive RMSNorm (if cond is provided and dense layer exists)
        if cond.shape[-1] != self.cond_dim:
            raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")

        modulation = self.dense(cond)
        # Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features]
        if len(x.shape) == 3:  # [batch, seq, features]
            modulation = modulation.unsqueeze(1)

        scale, shift, gate = torch.chunk(modulation, 3, dim=-1)

        # The modulation is applied in float32; the result is cast back afterwards.
        normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)

        return normed_inputs.to(dtype), gate.to(dtype)

    def extra_repr(self):
        """Return the summary shown inside ``repr(module)``."""
        if self.dense is None:
            return f"{tuple(self.weight.shape)}, eps={self.eps}"
        # Adaptive layers have no ``weight``; report the feature size from ``dim``.
        return f"{(self.dim,)}, eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}"
bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj class GemmaRotaryEmbedding(nn.Module): def __init__(self, config: GemmaConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def _gated_residual(x, y, gate): """ Applies gated residual connection with optional gate parameter. Args: x: Input tensor (residual) y: Output tensor to be added gate: Optional gate tensor to modulate the addition Returns: x + y if gate is None, otherwise x + y * gate """ if x is None and y is None: return None if x is None or y is None: return x if x is not None else y if gate is None: return x + y return x + y * gate def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() 
return attn_output, attn_weights class GemmaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: GemmaConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, use_cache: bool = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Use cache if provided if past_key_value is not None: if use_cache: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = 
{"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) else: key_states = torch.cat([past_key_value[self.layer_idx][0], key_states], dim=2) value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class GemmaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: GemmaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx) self.mlp = GemmaMLP(config) cond_dim = getattr(config, 'adarms_cond_dim', None) if getattr(config, 'use_adarms', False) else None self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC adarms_cond: Optional[torch.Tensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.FloatTensor, 
Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond) # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) hidden_states = _gated_residual(residual, hidden_states, gate) # Fully Connected residual = hidden_states hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond) hidden_states = self.mlp(hidden_states) hidden_states = _gated_residual(residual, hidden_states, gate) outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) return outputs @auto_docstring class GemmaPreTrainedModel(PreTrainedModel): config_class = GemmaConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["GemmaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_3 = True _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, GemmaRMSNorm): if hasattr(module, 'weight'): module.weight.data.fill_(1.0) @auto_docstring class GemmaModel(GemmaPreTrainedModel): def __init__(self, config: GemmaConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = 
config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) cond_dim = getattr(config, 'adarms_cond_dim', None) if getattr(config, 'use_adarms', False) else None self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) self.rotary_emb = GemmaRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, adarms_cond: Optional[torch.Tensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: """ adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): Condition for ADARMS. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
) use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = create_causal_mask( config=self.config, input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, position_ids=position_ids, ) # embed positions hidden_states = inputs_embeds # Convert to bfloat16 if the first layer uses bfloat16 if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: hidden_states = hidden_states.to(torch.bfloat16) # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) # normalized # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) #hidden_states = hidden_states * normalizer # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, adarms_cond=adarms_cond, **kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) 
hidden_states, _ = self.norm(hidden_states, adarms_cond) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @auto_docstring class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} def __init__(self, config): super().__init__(config) self.model = GemmaModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, adarms_cond: Optional[torch.Tensor] = None, **kwargs: Unpack[KwargsForCausalLM], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): Condition for ADARMS. Example: ```python >>> from transformers import AutoTokenizer, GemmaForCausalLM >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") >>> prompt = "What is your favorite condiment?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "What is your favorite condiment?" ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, adarms_cond=adarms_cond, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, 
vocab_size=self.config.vocab_size, **kwargs) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @auto_docstring( custom_intro=""" The Gemma Model transformer with a sequence classification head on top (linear layer). [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """ ) class GemmaForSequenceClassification(GemmaPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = GemmaModel(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, adarms_cond: Optional[torch.Tensor] = None, ) -> SequenceClassifierOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): Condition for ADARMS. """ transformer_outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, adarms_cond=adarms_cond, ) hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: batch_size = input_ids.shape[0] else: batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") if self.config.pad_token_id is None: last_non_pad_token = -1 elif input_ids is not None: # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) else: last_non_pad_token = -1 logger.warning_once( f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " "unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) @auto_docstring class GemmaForTokenClassification(GemmaPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = GemmaModel(config) if getattr(config, "classifier_dropout", None) is not None: classifier_dropout = config.classifier_dropout elif getattr(config, "hidden_dropout", None) is not None: classifier_dropout = config.hidden_dropout else: classifier_dropout = 0.1 self.dropout = nn.Dropout(classifier_dropout) self.score = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, adarms_cond: Optional[torch.Tensor] = None, ) -> TokenClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): Condition for ADARMS. """ outputs: BaseModelOutputWithPast = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, adarms_cond=adarms_cond, ) sequence_output = outputs.last_hidden_state sequence_output = self.dropout(sequence_output) logits = self.score(sequence_output) loss = None if labels is not None: loss = self.loss_function(logits, labels, self.config) return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) __all__ = [ "GemmaModel", "GemmaForCausalLM", "GemmaForSequenceClassification", "GemmaForTokenClassification", "GemmaPreTrainedModel", ] ================================================ FILE: src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py ================================================ # coding=utf-8 # Copyright 2024 the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""PyTorch PaliGemmamodel.""" from dataclasses import dataclass from typing import Optional, Union import torch import torch.utils.checkpoint from torch import nn from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging from ..auto import AutoModel from .configuration_paligemma import PaliGemmaConfig logger = logging.get_logger(__name__) @dataclass @auto_docstring( custom_intro=""" Base class for Paligemma outputs, with hidden states and attentions. """ ) class PaligemmaModelOutputWithPast(BaseModelOutputWithPast): r""" past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. """ image_hidden_states: Optional[torch.FloatTensor] = None @dataclass @auto_docstring( custom_intro=""" Base class for PaliGemma causal language model (or autoregressive) outputs. 
""" ) class PaliGemmaCausalLMOutputWithPast(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. image_hidden_states (`torch.FloatTensor`, *optional*): A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. image_hidden_states of the model produced by the vision encoder after projecting last hidden state. 
"""

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None


class PaliGemmaMultiModalProjector(nn.Module):
    """Single biased linear layer mapping vision features to `vision_config.projection_dim`."""

    def __init__(self, config: PaliGemmaConfig):
        super().__init__()
        self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)

    def forward(self, image_features):
        """Apply the linear projection to `image_features` and return the result."""
        hidden_states = self.linear(image_features)

        return hidden_states


# Shared base class: declares the capability flags and the weight initialisation used by the
# PaliGemma model classes that subclass it.
@auto_docstring
class PaliGemmaPreTrainedModel(PreTrainedModel):
    config_class = PaliGemmaConfig
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    _no_split_modules = ["PaliGemmaMultiModalProjector"]
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialise `nn.Linear` modules with normal weights and zero bias; other modules are left untouched."""
        # important: this ported version of PaliGemma isn't meant for training from scratch - only
        # inference and fine-tuning
        # Prefer the top-level `initializer_range`; fall back to the text config's value.
        std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()


@auto_docstring(
    custom_intro="""
    The Base Paligemma model which consists of a vision backbone and a language model withou language modeling head.,
    """
)
class PaliGemmaModel(PaliGemmaPreTrainedModel):
    _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
    accepts_loss_kwargs = False

    def __init__(self, config: PaliGemmaConfig):
        super().__init__(config)
        # Vision tower and language model are both instantiated from their sub-configs.
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size

        language_model = AutoModel.from_config(config=config.text_config)
        self.language_model = language_model

        # -1 is used as the pad id when the config does not define one.
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma
    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        """Replace the language model with `decoder`."""
        self.language_model = decoder

    def get_decoder(self):
        """Return the language model."""
        return self.language_model

    def _update_causal_mask(
        self,
        attention_mask,
        token_type_ids=None,
        past_key_values=None,
        cache_position=None,
        input_tensor=None,
        is_training: Optional[bool] = None,
    ):
        """
        Build the attention mask that is handed to the language model.

        - With `flash_attention_2`, the given `attention_mask` is returned unchanged if it contains a `0.0`
          entry, otherwise `None` is returned.
        - A 4D `attention_mask` is returned as-is.
        - Otherwise a mask of shape `(inputs_lead_dim, 1, sequence_length, target_length)` is built in
          `self.dtype`, where blocked positions hold `torch.finfo(self.dtype).min` and open positions hold 0.

        Args:
            attention_mask: 2D mask, 4D mask, or `None`.
            token_type_ids: Positions equal to 0 are unmasked when `is_training` is true.
            past_key_values: Cache object; `StaticCache`/`HybridCache` determine `target_length`.
            cache_position: Positions of the current tokens; also supplies the device of the mask.
            input_tensor: Tensor whose first two dims give the batch size and sequence length. Defaults to
                `attention_mask` when `None`.
            is_training: Defaults to `self.training` when `None`.

        Raises:
            ValueError: If `is_training` is true, `attention_mask` is given and `token_type_ids` is `None`.
        """
        if self.config.text_config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        is_training = is_training if is_training is not None else self.training
        using_static_cache = isinstance(past_key_values, StaticCache)
        min_dtype = torch.finfo(self.dtype).min
        if input_tensor is None:
            input_tensor = attention_mask

        inputs_lead_dim, sequence_length = input_tensor.shape[:2]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        elif isinstance(past_key_values, HybridCache):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else cache_position[0] + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            return attention_mask

        # Start fully blocked; the branches below open the positions that may be attended to.
        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
        )
        # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
        if sequence_length != 1:
            if is_training:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            else:
                causal_mask[:, :sequence_length] = 0.0

        # Keep the block only on key positions that lie strictly after each query's cache position.
        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]

            # First unmask prefix tokens during training
            if is_training:
                if token_type_ids is None:
                    raise ValueError("Token type ids must be provided during training")
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
                )

            # Then apply padding mask (will mask pad tokens)
            # The sum is 0 exactly where the position is open (0) AND attention_mask is 0 (padding).
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

        return causal_mask

    def get_image_features(self, pixel_values: torch.FloatTensor):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
               The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        return image_features

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, PaligemmaModelOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

        >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
        >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")

        >>> prompt = "Where is the cat standing?"
        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt,  return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs,)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Where is the cat standing?\nsnow"
        ```"""

        # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # `labels` is only used here to decide whether the training-style mask is built.
        is_training = token_type_ids is not None and labels is not None

        # Replace image id with PAD if the image token if OOV, to avoid index-errors
        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_id
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0) + 1  # Paligemma positions are 1-indexed

        # Merge text and images
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values)

            if input_ids is None:
                # Without ids, image slots are located by comparing against the image token's embedding.
                special_image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            else:
                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
                raise ValueError(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
                    "tokens from image embeddings."
                )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            # Overwrite the masked (image-token) embedding entries with the projected image features.
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        causal_mask = self._update_causal_mask(
            attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
        )
        outputs = self.language_model(
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        return PaligemmaModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )


# Keyword-argument type used to annotate `**kwargs` of the causal-LM forward pass.
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@auto_docstring(
    custom_intro="""
    The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
    """
)
class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
    # Wraps a `PaliGemmaModel` and adds a bias-free linear `lm_head` that maps hidden states to vocabulary logits.
    _checkpoint_conversion_mapping = {
        "^language_model.model": "model.language_model",
        "^vision_tower": "model.vision_tower",
        "^multi_modal_projector": "model.multi_modal_projector",
        "^language_model.lm_head": "lm_head",
    }
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: PaliGemmaConfig):
        super().__init__(config)
        self.model = PaliGemmaModel(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the input embeddings of the wrapped model."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Set the input embeddings of the wrapped model."""
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the decoder of the wrapped model."""
        self.model.set_decoder(decoder)

    def get_decoder(self):
        """Return the decoder of the wrapped model."""
        return self.model.get_decoder()

    def get_image_features(self, pixel_values):
        """Delegate to `PaliGemmaModel.get_image_features`."""
        return self.model.get_image_features(pixel_values)

    # Make modules available through conditional class for BC
    @property
    def language_model(self):
        return self.model.language_model

    @property
    def vision_tower(self):
        return self.model.vision_tower

    @property
    def multi_modal_projector(self):
        return self.model.multi_modal_projector

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[tuple, PaliGemmaCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

        >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
        >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")

        >>> prompt = "Where is the cat standing?"
        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt,  return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs,)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Where is the cat standing?\nsnow"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # An int keeps the last `logits_to_keep` positions (0 keeps all, since `slice(-0, None)` is the
        # full slice); a tensor is used directly as an index along the sequence dimension.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        return PaliGemmaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        pixel_values=None,
        attention_mask=None,
        token_type_ids=None,
        use_cache=True,
        logits_to_keep=None,
        labels=None,
        **kwargs,
    ):
        """
        Extend the parent class's generation inputs with PaliGemma-specific handling.

        Shifts `position_ids` by one, forwards `pixel_values` only on the first step
        (`cache_position[0] == 0`), and on that step pre-computes the attention mask when
        `past_key_values` is a `HybridCache`.

        NOTE(review): `cache_position` is indexed unconditionally, so it must not be `None` here.
        """
        # Overwritten -- custom `position_ids` and `pixel_values` handling
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            token_type_ids=token_type_ids,
            **kwargs,
        )

        # position_ids in Paligemma are 1-indexed
        if model_inputs.get("position_ids") is not None:
            model_inputs["position_ids"] += 1
        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
        # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
        is_training = token_type_ids is not None and labels is not None
        if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
            input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
            causal_mask = self.model._update_causal_mask(
                attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
            )
            model_inputs["attention_mask"] = causal_mask

        return model_inputs

    @staticmethod
    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Keep the block only on key positions that lie strictly after each query's cache position.
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # The sum is 0 exactly where the position is open (0) AND attention_mask is 0 (padding).
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"]

================================================
FILE: src/openpi/models_pytorch/transformers_replace/models/siglip/check.py
================================================
import transformers
def check_whether_transformers_replace_is_installed_correctly(): return transformers.__version__ == "4.53.2" ================================================ FILE: src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py ================================================ # coding=utf-8 # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Siglip model.""" import math import warnings from dataclasses import dataclass from typing import Any, Callable, Optional, Union import numpy as np import torch import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn.init import _calculate_fan_in_and_fan_out from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig logger = logging.get_logger(__name__) def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on 
https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 if (mean < a - 2 * std) or (mean > b + 2 * std): warnings.warn( "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2, ) # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values l = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. tensor.uniform_(2 * l - 1, 2 * u - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal tensor.erfinv_() # Transform to proper mean, std tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) # Clamp to ensure it's in the proper range tensor.clamp_(min=a, max=b) def trunc_normal_tf_( tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 ) -> torch.Tensor: """Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for generating the random values works best when :math:`a \\leq \text{mean} \\leq b`. NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 and the result is subsequently scaled and shifted by the mean and std args. 
Args: tensor: an n-dimensional `torch.Tensor` mean: the mean of the normal distribution std: the standard deviation of the normal distribution a: the minimum cutoff value b: the maximum cutoff value """ with torch.no_grad(): _trunc_normal_(tensor, 0, 1.0, a, b) tensor.mul_(std).add_(mean) def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == "fan_in": denom = fan_in elif mode == "fan_out": denom = fan_out elif mode == "fan_avg": denom = (fan_in + fan_out) / 2 variance = scale / denom if distribution == "truncated_normal": # constant is stddev of standard normal truncated to (-2, 2) trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) elif distribution == "normal": with torch.no_grad(): tensor.normal_(std=math.sqrt(variance)) elif distribution == "uniform": bound = math.sqrt(3 * variance) with torch.no_grad(): tensor.uniform_(-bound, bound) else: raise ValueError(f"invalid distribution {distribution}") def lecun_normal_(tensor): variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") def default_flax_embed_init(tensor): variance_scaling_(tensor, mode="fan_in", distribution="normal") @dataclass @auto_docstring( custom_intro=""" Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. """ ) # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip class SiglipVisionModelOutput(ModelOutput): r""" image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): The image embeddings obtained by applying the projection layer to the pooler_output. 
""" image_embeds: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None attentions: Optional[tuple[torch.FloatTensor, ...]] = None @dataclass @auto_docstring( custom_intro=""" Base class for text model's outputs that also contains a pooling of the last hidden states. """ ) # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip class SiglipTextModelOutput(ModelOutput): r""" text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): The text embeddings obtained by applying the projection layer to the pooler_output. """ text_embeds: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None attentions: Optional[tuple[torch.FloatTensor, ...]] = None @dataclass @auto_docstring # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip class SiglipOutput(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): Contrastive loss for image-text similarity. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text similarity scores. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image similarity scores. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. 
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. text_model_output (`BaseModelOutputWithPooling`): The output of the [`SiglipTextModel`]. vision_model_output (`BaseModelOutputWithPooling`): The output of the [`SiglipVisionModel`]. """ loss: Optional[torch.FloatTensor] = None logits_per_image: Optional[torch.FloatTensor] = None logits_per_text: Optional[torch.FloatTensor] = None text_embeds: Optional[torch.FloatTensor] = None image_embeds: Optional[torch.FloatTensor] = None text_model_output: BaseModelOutputWithPooling = None vision_model_output: BaseModelOutputWithPooling = None def to_tuple(self) -> tuple[Any]: return tuple( self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() for k in self.keys() ) class SiglipVisionEmbeddings(nn.Module): def __init__(self, config: SiglipVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, padding="valid", ) self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. This method is also adapted to support torch.jit tracing and no class embeddings. 
Adapted from: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_patches = embeddings.shape[1] num_positions = self.position_embedding.weight.shape[0] # always interpolate when tracing to ensure the exported model works for dynamic input shapes if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embedding(self.position_ids) patch_pos_embed = self.position_embedding.weight.unsqueeze(0) dim = embeddings.shape[-1] new_height = height // self.patch_size new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False, ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return patch_pos_embed def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: _, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] embeddings = patch_embeds.flatten(2).transpose(1, 2) if interpolate_pos_encoding: embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) else: embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip class SiglipTextEmbeddings(nn.Module): def __init__(self, config: SiglipTextConfig): super().__init__() embed_dim = config.hidden_size self.token_embedding = 
nn.Embedding(config.vocab_size, embed_dim) self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) def forward( self, input_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] max_position_embedding = self.position_embedding.weight.shape[0] if seq_length > max_position_embedding: raise ValueError( f"Sequence length must be less than max_position_embeddings (got `sequence length`: " f"{seq_length} and max_position_embeddings: {max_position_embedding}" ) if position_ids is None: position_ids = self.position_ids[:, :seq_length] if inputs_embeds is None: inputs_embeds = self.token_embedding(input_ids) position_embeddings = self.position_embedding(position_ids) embeddings = inputs_embeds + position_embeddings return embeddings def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class SiglipAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config): super().__init__() self.config = config 
self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" batch_size, seq_length, embed_dim = hidden_states.shape queries = self.q_proj(hidden_states) keys = self.k_proj(hidden_states) values = self.v_proj(hidden_states) queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
    """Feed-forward block: `fc1` -> activation -> `fc2`.

    Layer widths come from `config.hidden_size` and `config.intermediate_size`; the activation is
    looked up in `ACT2FN` under `config.hidden_act`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up with `fc1`, apply the activation, then project back down with `fc2`."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))


class SiglipEncoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer layer: self-attention and MLP, each wrapped in a residual connection."""

    def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]):
        super().__init__()
        hidden_size = config.hidden_size
        norm_eps = config.layer_norm_eps
        self.embed_dim = hidden_size
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.self_attn = SiglipAttention(config)
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by
                very large negative values. It is handed to the attention module unchanged.
            output_attentions (`bool`, *optional*, defaults to `False`):
                When truthy, the attention weights returned by the attention module are appended to the result.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        # Attention sub-block: normalise first, then add the result back onto the un-normalised input.
        attn_output, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block, same pre-norm + residual pattern.
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip
class SiglipEncoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`SiglipEncoderLayer`] modules that are applied one after another.

    Args:
        config: SiglipConfig
    """

    def __init__(self, config: SiglipConfig):
        super().__init__()
        self.config = config
        encoder_layers = [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(encoder_layers)
        self.gradient_checkpointing = False

    # Ignore copy
    @can_return_tuple
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutput:
        r"""
        Run `inputs_embeds` through every layer in order.

        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Already-embedded inputs; they are used as the hidden states fed to the first layer.
            attention_mask (`torch.Tensor`, *optional*):
                Passed unchanged to every layer.
            output_attentions (`bool`, *optional*):
                Whether to collect the attention weights of every layer. Falls back to
                `config.output_attentions` when `None`.
            output_hidden_states (`bool`, *optional*):
                Whether to collect the hidden states (the input plus the output of every layer). Falls back to
                `config.output_hidden_states` when `None`.

        Returns:
            [`BaseModelOutput`] whose `hidden_states` / `attentions` fields are tuples when the matching flag is
            truthy and `None` otherwise.

        NOTE(review): tuple-style returns are presumably handled by the `can_return_tuple` decorator, which is
        defined outside this block -- confirm.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # `None` means "do not collect"; a list means "collect, convert to a tuple at the end".
        collected_states = [] if output_hidden_states else None
        collected_attentions = [] if output_attentions else None

        hidden_states = inputs_embeds
        for layer in self.layers:
            if collected_states is not None:
                # Record the input of this layer (so the first entry is `inputs_embeds` itself).
                collected_states.append(hidden_states)

            layer_outputs = layer(
                hidden_states,
                attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

            if collected_attentions is not None:
                collected_attentions.append(layer_outputs[1])

        if collected_states is not None:
            # Finally record the output of the last layer.
            collected_states.append(hidden_states)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=None if collected_states is None else tuple(collected_states),
            attentions=None if collected_attentions is None else tuple(collected_attentions),
        )
self.encoder = SiglipEncoder(config) self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.head = nn.Linear(embed_dim, config.projection_size) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" @can_return_tuple @auto_docstring def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> BaseModelOutputWithPooling: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) if input_ids is None: raise ValueError("You have to specify input_ids") input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. # expand attention_mask if attention_mask is not None and not self._use_flash_attention_2: # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.final_layer_norm(last_hidden_state) # Assuming "sticky" EOS tokenization, last token is always EOS. 
class SiglipVisionTransformer(nn.Module):
    """Vision tower: patch embeddings -> encoder -> final layer norm -> optional pooling head."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # The pooling head is built unless the config carries `vision_use_head` and it is falsy.
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(config)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
    ) -> BaseModelOutputWithPooling:
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        embedded = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        # Convert to bfloat16 if the encoder uses bfloat16 (judged by the first layer's query projection).
        encoder_layers = self.encoder.layers
        if len(encoder_layers) > 0:
            if encoder_layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
                embedded = embedded.to(torch.bfloat16)

        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=embedded,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        normed = self.post_layernorm(encoder_outputs.last_hidden_state)

        # Without a head there is no pooled representation to report.
        pooled = self.head(normed) if self.use_head else None

        return BaseModelOutputWithPooling(
            last_hidden_state=normed,
            pooler_output=pooled,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = SiglipMLP(config) def forward(self, hidden_state): batch_size = hidden_state.shape[0] probe = self.probe.repeat(batch_size, 1, 1) hidden_state = self.attention(probe, hidden_state, hidden_state)[0] residual = hidden_state hidden_state = self.layernorm(hidden_state) hidden_state = residual + self.mlp(hidden_state) return hidden_state[:, 0] @auto_docstring( custom_intro=""" The vision model from SigLIP without any head or projection on top. """ ) class SiglipVisionModel(SiglipPreTrainedModel): config_class = SiglipVisionConfig main_input_name = "pixel_values" def __init__(self, config: SiglipVisionConfig): super().__init__(config) self.vision_model = SiglipVisionTransformer(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding @can_return_tuple @auto_docstring def forward( self, pixel_values, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, ) -> BaseModelOutputWithPooling: r""" Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, SiglipVisionModel >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled features ```""" return self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, 
@auto_docstring
class SiglipModel(SiglipPreTrainedModel):
    config_class = SiglipConfig

    def __init__(self, config: SiglipConfig):
        super().__init__(config)

        # Both sub-configs have to be the dedicated SigLIP config classes.
        text_config = config.text_config
        if not isinstance(text_config, SiglipTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type SiglipTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        vision_config = config.vision_config
        if not isinstance(vision_config, SiglipVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        # Build the wrapper models through `_from_config` (so the attention implementation is set up
        # properly) and keep only their inner transformers, for backward compatibility.
        self.text_model = SiglipTextModel._from_config(text_config).text_model
        self.vision_model = SiglipVisionModel._from_config(vision_config).vision_model

        self.logit_scale = nn.Parameter(torch.randn(1))
        self.logit_bias = nn.Parameter(torch.randn(1))

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
        >>> with torch.no_grad():
        ...     text_features = model.get_text_features(**inputs)
        ```"""
        # Flags left unset are taken from this model's own config rather than the text sub-config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        return text_outputs.pooler_output

    @auto_docstring
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     image_features = model.get_image_features(**inputs)
        ```"""
        # Flags left unset are taken from this model's own config rather than the vision sub-config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        return vision_outputs.pooler_output

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> SiglipOutput:
        r"""
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
        >>> # important: we pass `padding=max_length` since the model was trained with this
        >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> logits_per_image = outputs.logits_per_image
        >>> probs = torch.sigmoid(logits_per_image)  # these are the probabilities
        >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
        31.9% that image 0 is 'a photo of 2 cats'
        ```"""
        # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # L2-normalise the pooled features of both towers.
        image_embeds = vision_outputs.pooler_output
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_outputs.pooler_output
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # Cosine similarity as logits; everything is moved onto the text embeddings' device first.
        target_device = text_embeds.device
        similarity = torch.matmul(text_embeds, image_embeds.t().to(target_device))
        scale = self.logit_scale.to(target_device)
        bias = self.logit_bias.to(target_device)
        logits_per_text = similarity * scale.exp() + bias
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
            # Sign matrix: +1 on the diagonal, -1 everywhere else.
            identity = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
            pair_signs = -torch.ones_like(logits_per_text) + 2 * identity
            log_likelihood = torch.nn.functional.logsigmoid(pair_signs * logits_per_text)
            # Sum over the last dimension, then average over rows.
            per_row_nll = -torch.sum(log_likelihood, dim=-1)
            loss = per_row_nll.mean()

        return SiglipOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
on top (a linear layer on top of the pooled final hidden states of the patch tokens) e.g. for ImageNet. """ ) class SiglipForImageClassification(SiglipPreTrainedModel): main_input_name = "pixel_values" def __init__(self, config: SiglipConfig) -> None: super().__init__(config) self.num_labels = config.num_labels # Create the vision model with proper attention # and take only vision_model submodule (for backward compatibility) vision_model = SiglipVisionModel._from_config(config.vision_config) self.vision_model = vision_model.vision_model # Classifier head self.classifier = ( nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() ) # Initialize weights and apply final processing self.post_init() @can_return_tuple @auto_docstring def forward( self, pixel_values: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, ) -> ImageClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the image classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). Examples: ```python >>> from transformers import AutoImageProcessor, SiglipForImageClassification >>> import torch >>> from PIL import Image >>> import requests >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> # note: we are loading a `SiglipModel` from the hub here, >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above. 
>>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224") >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224") >>> inputs = image_processor(images=image, return_tensors="pt") >>> outputs = model(**inputs) >>> logits = outputs.logits >>> # model predicts one of the two classes >>> predicted_class_idx = logits.argmax(-1).item() >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) Predicted class: LABEL_1 ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, ) sequence_output = outputs.last_hidden_state # average pool the patch tokens sequence_output = torch.mean(sequence_output, dim=1) # apply classifier logits = self.classifier(sequence_output) loss = None if labels is not None: # move labels to correct device to enable model parallelism labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = 
def make_aloha_example() -> dict:
    """Creates a random input example for the Aloha policy."""
    camera_names = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")
    # One random uint8 image in [channel, height, width] layout per camera.
    random_images = {name: np.random.randint(256, size=(3, 224, 224), dtype=np.uint8) for name in camera_names}
    return {
        "state": np.ones((14,)),
        "images": random_images,
        "prompt": "do something",
    }


@dataclasses.dataclass(frozen=True)
class AlohaInputs(transforms.DataTransformFn):
    """Inputs for the Aloha policy.

    Expected inputs:
    - images: dict[name, img] where img is [channel, height, width]. name must be in EXPECTED_CAMERAS.
    - state: [14]
    - actions: [action_horizon, 14]
    """

    # If true, this will convert the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime which was used to train the base model.
    adapt_to_pi: bool = True

    # The expected cameras names. All input cameras must be in this set. A missing wrist camera is
    # replaced with a black image and the corresponding `image_mask` is set to False. `cam_high` is
    # required; `cam_low` is accepted but not forwarded to the model inputs.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        """Turn an Aloha sample into the `image` / `image_mask` / `state` (+ `actions`, `prompt`) layout.

        Raises:
            ValueError: if `data["images"]` holds a camera name that is not in `EXPECTED_CAMERAS`.
        """
        data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)

        in_images = data["images"]
        unexpected_cameras = set(in_images) - set(self.EXPECTED_CAMERAS)
        if unexpected_cameras:
            raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

        # Assume that base image always exists.
        base_image = in_images["cam_high"]
        images = {"base_0_rgb": base_image}
        image_masks = {"base_0_rgb": np.True_}

        # Add the wrist images; a missing one becomes an all-zero image that is masked out.
        wrist_slots = (
            ("left_wrist_0_rgb", "cam_left_wrist"),
            ("right_wrist_0_rgb", "cam_right_wrist"),
        )
        for dest, source in wrist_slots:
            is_present = source in in_images
            images[dest] = in_images[source] if is_present else np.zeros_like(base_image)
            image_masks[dest] = np.True_ if is_present else np.False_

        inputs = {
            "image": images,
            "image_mask": image_masks,
            "state": data["state"],
        }

        # Actions are only available during training.
        if "actions" in data:
            inputs["actions"] = _encode_actions_inv(np.asarray(data["actions"]), adapt_to_pi=self.adapt_to_pi)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs
def _joint_flip_mask() -> np.ndarray:
    """Used to convert between aloha and pi joint angles.

    Returns a length-14 array of +1/-1 factors. The dtype is float on purpose: with an integer mask,
    an integer-typed state/action array stays integer after the multiplication and the gripper values
    written into it afterwards are silently truncated. Floating-point inputs promote to float64 exactly
    as they did with the integer mask.
    """
    return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1], dtype=np.float64)


def _normalize(x, min_val, max_val):
    """Map `x` linearly so that `min_val` -> 0 and `max_val` -> 1."""
    return (x - min_val) / (max_val - min_val)


def _unnormalize(x, min_val, max_val):
    """Inverse of `_normalize`: map 0 -> `min_val` and 1 -> `max_val`."""
    return x * (max_val - min_val) + min_val


def _gripper_to_angular(value):
    """Convert Aloha's normalized linear gripper position into pi0's normalized angular space."""
    # Aloha transforms the gripper positions into a linear space. The following code
    # reverses this transformation to be consistent with pi0 which is pretrained in
    # angular space.
    #
    # These values are coming from the Aloha code:
    # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
    value = _unnormalize(value, min_val=0.01844, max_val=0.05800)

    # This is the inverse of the angular to linear transformation inside the Interbotix code.
    def linear_to_radian(linear_position, arm_length, horn_radius):
        ratio = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        # Clip so that rounding noise cannot push the arcsin argument outside [-1, 1].
        return np.arcsin(np.clip(ratio, -1.0, 1.0))

    # The constants are taken from the Interbotix code.
    value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)

    # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110).
    # There are 4096 total encoder counts and aloha uses a zero of 2048.
    # Converting this to radians means that the normalized inputs are between (0.5476, 1.6296)
    return _normalize(value, min_val=0.5476, max_val=1.6296)


def _gripper_from_angular(value):
    """Convert from the gripper position used by pi0 to the gripper position that is used by Aloha."""
    # Note that the units are still angular but the range is different.

    # We do not scale the output since the trossen model predictions are already in radians.
    # See the comment in _gripper_to_angular for a derivation of the constant
    value = value + 0.5476

    # These values are coming from the Aloha code:
    # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
    return _normalize(value, min_val=-0.6213, max_val=1.4910)


def _gripper_from_angular_inv(value):
    """Directly inverts the `_gripper_from_angular` function."""
    value = _unnormalize(value, min_val=-0.6213, max_val=1.4910)
    return value - 0.5476


def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Return a copy of `data` with a decoded `state` and images converted to uint8 [height, width, channel].

    The caller's dict is left untouched; entries other than `state` and `images` are carried over as-is
    (shallow copy), in their original order.

    Args:
        data: sample with a `state` entry and an `images` mapping of name -> [channel, height, width] image.
        adapt_to_pi: forwarded to `_decode_state`.
    """
    # state is [left_arm_joint_angles, left_arm_gripper, right_arm_joint_angles, right_arm_gripper]
    # dim sizes: [6, 1, 6, 1]
    state = _decode_state(np.asarray(data["state"]), adapt_to_pi=adapt_to_pi)

    def convert_image(img):
        img = np.asarray(img)
        # Convert to uint8 if using float images.
        if np.issubdtype(img.dtype, np.floating):
            img = (255 * img).astype(np.uint8)
        # Convert from [channel, height, width] to [height, width, channel].
        return einops.rearrange(img, "c h w -> h w c")

    images_dict = {name: convert_image(img) for name, img in data["images"].items()}

    # Build a new dict instead of overwriting the caller's entries in place.
    return {**data, "images": images_dict, "state": state}


def _decode_state(state: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
    """Convert an Aloha state vector to the pi convention when `adapt_to_pi` is set.

    With `adapt_to_pi=False` the input is returned unchanged. Otherwise a new float array is returned
    (the input is not modified) with the joint signs flipped and entries 6 and 13 (the two grippers)
    converted with `_gripper_to_angular`.
    """
    if adapt_to_pi:
        # Flip the joints. The multiplication allocates a new array, so the write below is safe.
        state = _joint_flip_mask() * state
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        state[[6, 13]] = _gripper_to_angular(state[[6, 13]])
    return state
@dataclasses.dataclass(frozen=True)
class DroidInputs(transforms.DataTransformFn):
    """Repack a DROID-style observation dict into the model's input dict.

    The output always has "state", "image" and "image_mask"; "actions" and "prompt"
    are forwarded when present in the input.
    """

    # Determines which image slot names / masks are produced.
    model_type: _model.ModelType

    def __call__(self, data: dict) -> dict:
        """Build the model input dict from `data`.

        Raises:
            ValueError: If `model_type` is not PI0, PI05 or PI0_FAST.
        """
        gripper = np.asarray(data["observation/gripper_position"])
        if gripper.ndim == 0:
            # A scalar gripper position must become 1D so it can be concatenated
            # with the joint positions.
            gripper = gripper[np.newaxis]
        state = np.concatenate([data["observation/joint_position"], gripper])

        # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically
        # stores as float32 (C,H,W), gets skipped for policy inference
        exterior = _parse_image(data["observation/exterior_image_1_left"])
        wrist = _parse_image(data["observation/wrist_image_left"])

        # Each slot is (name, image, mask); the slot order defines the dict order.
        kind = self.model_type
        if kind == _model.ModelType.PI0 or kind == _model.ModelType.PI05:
            slots = (
                ("base_0_rgb", exterior, np.True_),
                ("left_wrist_0_rgb", wrist, np.True_),
                ("right_wrist_0_rgb", np.zeros_like(exterior), np.False_),
            )
        elif kind == _model.ModelType.PI0_FAST:
            # We don't mask out padding images for FAST models.
            slots = (
                ("base_0_rgb", exterior, np.True_),
                ("base_1_rgb", np.zeros_like(exterior), np.True_),
                ("wrist_0_rgb", wrist, np.True_),
            )
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

        inputs = {
            "state": state,
            "image": {name: image for name, image, _ in slots},
            "image_mask": {name: mask for name, _, mask in slots},
        }

        if "actions" in data:
            inputs["actions"] = np.asarray(data["actions"])

        if "prompt" in data:
            prompt = data["prompt"]
            if isinstance(prompt, bytes):
                # The decoded prompt is also written back into the input dict.
                prompt = prompt.decode("utf-8")
                data["prompt"] = prompt
            inputs["prompt"] = prompt

        return inputs
@dataclasses.dataclass(frozen=True)
class LiberoInputs(transforms.DataTransformFn):
    """Convert a Libero-style sample into the input dict expected by the model.

    Used for both training and inference. To adapt this to your own dataset, copy
    the class and change only the *input* keys read from `data`; the keys of the
    returned dict must stay as they are.
    """

    # Determines which model will be used.
    # Do not change this for your own dataset.
    model_type: _model.ModelType

    def __call__(self, data: dict) -> dict:
        """Return the model input dict built from `data`."""
        # Images may need to be parsed to uint8 (H,W,C) since LeRobot stores them as
        # float32 (C,H,W); this is a no-op path during policy inference.
        # If your dataset keeps images under keys other than "observation/image" or
        # "observation/wrist_image", change the keys here.
        # Pi0 models currently take three image inputs: one third-person view and two
        # wrist views (left and right). A view your dataset lacks can be replaced by
        # zeros, as is done for the right wrist view below.
        base_image = _parse_image(data["observation/image"])
        wrist_image = _parse_image(data["observation/wrist_image"])
        state = data["observation/state"]

        images = {
            "base_0_rgb": base_image,
            "left_wrist_0_rgb": wrist_image,
            # Non-existent views are padded with zero arrays of the appropriate shape.
            "right_wrist_0_rgb": np.zeros_like(base_image),
        }

        # Padding images are masked out only for pi0, not for pi0-FAST.
        # Do not change this for your own dataset.
        if self.model_type == _model.ModelType.PI0_FAST:
            padding_mask = np.True_
        else:
            padding_mask = np.False_
        image_masks = {
            "base_0_rgb": np.True_,
            "left_wrist_0_rgb": np.True_,
            "right_wrist_0_rgb": padding_mask,
        }

        # Do not change the keys of this dict.
        inputs = {
            "state": state,
            "image": images,
            "image_mask": image_masks,
        }

        # "actions" (only available during training) and "prompt" (the language
        # instruction) are forwarded unchanged when present; no padding happens in
        # this transform. If your dataset stores the instruction under another key,
        # read it from there, but the output key must remain "prompt".
        for optional_key in ("actions", "prompt"):
            if optional_key in data:
                inputs[optional_key] = data[optional_key]

        return inputs
class Policy(BasePolicy):
    """Runs a model's `sample_actions` on a single observation.

    Input transforms are applied before the model call and output transforms after
    it. Both JAX and PyTorch models are supported, selected by `is_pytorch`.
    """

    def __init__(
        self,
        model: _model.BaseModel,
        *,
        rng: at.KeyArrayLike | None = None,
        transforms: Sequence[_transforms.DataTransformFn] = (),
        output_transforms: Sequence[_transforms.DataTransformFn] = (),
        sample_kwargs: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        pytorch_device: str = "cpu",
        is_pytorch: bool = False,
    ):
        """Initialize the Policy.

        Args:
            model: The model to use for action sampling.
            rng: Random number generator key for JAX models. Ignored for PyTorch models.
                When None, `jax.random.key(0)` is used.
            transforms: Input data transformations to apply before inference.
            output_transforms: Output data transformations to apply after inference.
            sample_kwargs: Additional keyword arguments to pass to model.sample_actions.
            metadata: Additional metadata to store with the policy.
            pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda:0").
                          Only relevant when is_pytorch=True.
            is_pytorch: Whether the model is a PyTorch model. If False, assumes JAX model.
        """
        self._model = model
        self._input_transform = _transforms.compose(transforms)
        self._output_transform = _transforms.compose(output_transforms)
        self._sample_kwargs = sample_kwargs or {}
        self._metadata = metadata or {}
        self._is_pytorch_model = is_pytorch
        self._pytorch_device = pytorch_device

        if self._is_pytorch_model:
            self._model = self._model.to(pytorch_device)
            self._model.eval()
            self._sample_actions = model.sample_actions
        else:
            # JAX model setup
            self._sample_actions = nnx_utils.module_jit(model.sample_actions)
            # Compare against None explicitly: `rng or default` would truth-test the
            # key array, which is not a valid "was it provided" check for JAX arrays
            # (e.g. it errors for multi-element legacy uint32 keys).
            self._rng = rng if rng is not None else jax.random.key(0)

    @override
    def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict:  # type: ignore[misc]
        """Sample actions for a single (unbatched) observation.

        Args:
            obs: Observation dict; it is copied before the input transforms run.
            noise: Optional noise forwarded to `sample_actions` as the `noise` kwarg.
                A 2D array is given a leading batch dimension.

        Returns:
            The output-transformed dict built from "state" and "actions", plus a
            "policy_timing" entry with the `sample_actions` wall time in milliseconds.
        """
        # Make a copy since transformations may modify the inputs in place.
        inputs = jax.tree.map(lambda x: x, obs)
        inputs = self._input_transform(inputs)
        if not self._is_pytorch_model:
            # Make a batch and convert to jax.Array.
            inputs = jax.tree.map(lambda x: jnp.asarray(x)[np.newaxis, ...], inputs)
            # Advance the stored key so every call samples with a fresh sub-key.
            self._rng, sample_rng_or_pytorch_device = jax.random.split(self._rng)
        else:
            # Convert inputs to PyTorch tensors and move to correct device
            inputs = jax.tree.map(lambda x: torch.from_numpy(np.array(x)).to(self._pytorch_device)[None, ...], inputs)
            sample_rng_or_pytorch_device = self._pytorch_device

        # Prepare kwargs for sample_actions
        sample_kwargs = dict(self._sample_kwargs)
        if noise is not None:
            noise = torch.from_numpy(noise).to(self._pytorch_device) if self._is_pytorch_model else jnp.asarray(noise)

            if noise.ndim == 2:  # If noise is (action_horizon, action_dim), add batch dimension
                noise = noise[None, ...]  # Make it (1, action_horizon, action_dim)
            sample_kwargs["noise"] = noise

        observation = _model.Observation.from_dict(inputs)
        start_time = time.monotonic()
        outputs = {
            "state": inputs["state"],
            "actions": self._sample_actions(sample_rng_or_pytorch_device, observation, **sample_kwargs),
        }
        model_time = time.monotonic() - start_time
        # Drop the batch dimension again and convert every leaf to numpy.
        if self._is_pytorch_model:
            outputs = jax.tree.map(lambda x: np.asarray(x[0, ...].detach().cpu()), outputs)
        else:
            outputs = jax.tree.map(lambda x: np.asarray(x[0, ...]), outputs)

        outputs = self._output_transform(outputs)
        outputs["policy_timing"] = {
            "infer_ms": model_time * 1000,
        }
        return outputs

    @property
    def metadata(self) -> dict[str, Any]:
        """Metadata dict supplied at construction (empty dict if none was given)."""
        return self._metadata
repack_transforms: transforms.Group | None = None, sample_kwargs: dict[str, Any] | None = None, default_prompt: str | None = None, norm_stats: dict[str, transforms.NormStats] | None = None, pytorch_device: str | None = None, ) -> _policy.Policy: """Create a policy from a trained checkpoint. Args: train_config: The training config to use to create the model. checkpoint_dir: The directory to load the model from. repack_transforms: Optional transforms that will be applied before any other transforms. sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default kwargs will be used. default_prompt: The default prompt to use for the policy. Will inject the prompt into the input data if it doesn't already exist. norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded from the checkpoint directory. pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0"). If None and is_pytorch=True, will use "cuda" if available, otherwise "cpu". Note: The function automatically detects whether the model is PyTorch-based by checking for the presence of "model.safensors" in the checkpoint directory. 
""" repack_transforms = repack_transforms or transforms.Group() checkpoint_dir = download.maybe_download(str(checkpoint_dir)) # Check if this is a PyTorch model by looking for model.safetensors weight_path = os.path.join(checkpoint_dir, "model.safetensors") is_pytorch = os.path.exists(weight_path) logging.info("Loading model...") if is_pytorch: model = train_config.model.load_pytorch(train_config, weight_path) model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16") else: model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16)) data_config = train_config.data.create(train_config.assets_dirs, train_config.model) if norm_stats is None: # We are loading the norm stats from the checkpoint instead of the config assets dir to make sure # that the policy is using the same normalization stats as the original training process. if data_config.asset_id is None: raise ValueError("Asset id is required to load norm stats.") norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets", data_config.asset_id) # Determine the device to use for PyTorch models if is_pytorch and pytorch_device is None: try: import torch pytorch_device = "cuda" if torch.cuda.is_available() else "cpu" except ImportError: pytorch_device = "cpu" return _policy.Policy( model, transforms=[ *repack_transforms.inputs, transforms.InjectDefaultPrompt(default_prompt), *data_config.data_transforms.inputs, transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm), *data_config.model_transforms.inputs, ], output_transforms=[ *data_config.model_transforms.outputs, transforms.Unnormalize(norm_stats, use_quantiles=data_config.use_quantile_norm), *data_config.data_transforms.outputs, *repack_transforms.outputs, ], sample_kwargs=sample_kwargs, metadata=train_config.policy_metadata, is_pytorch=is_pytorch, pytorch_device=pytorch_device if is_pytorch else None, ) ================================================ FILE: 
class WebsocketPolicyServer:
    """Serves a policy over websockets.

    On every new connection the server first sends its metadata, then answers each
    received observation message with the result of `policy.infer`, annotated with
    server-side timing. See websocket_client_policy.py for a client implementation.
    """

    def __init__(
        self,
        policy: _base_policy.BasePolicy,
        host: str = "0.0.0.0",
        port: int | None = None,
        metadata: dict | None = None,
    ) -> None:
        """Store the serving configuration and set the websockets server log level to INFO."""
        self._policy = policy
        self._host = host
        self._port = port
        self._metadata = metadata or {}
        logging.getLogger("websockets.server").setLevel(logging.INFO)

    def serve_forever(self) -> None:
        """Run the server in a fresh asyncio event loop; blocks the calling thread."""
        asyncio.run(self.run())

    async def run(self):
        """Start the websocket server and serve until cancelled."""
        server_ctx = _server.serve(
            self._handler,
            self._host,
            self._port,
            compression=None,
            max_size=None,
            process_request=_health_check,
        )
        async with server_ctx as server:
            await server.serve_forever()

    async def _handler(self, websocket: _server.ServerConnection):
        """Serve one client connection until it closes or an error occurs."""
        logger.info(f"Connection from {websocket.remote_address} opened")
        packer = msgpack_numpy.Packer()

        # The metadata is always the first frame a client receives.
        await websocket.send(packer.pack(self._metadata))

        last_roundtrip = None
        while True:
            try:
                loop_started = time.monotonic()
                obs = msgpack_numpy.unpackb(await websocket.recv())

                infer_started = time.monotonic()
                action = self._policy.infer(obs)
                infer_elapsed = time.monotonic() - infer_started

                timing = {"infer_ms": infer_elapsed * 1000}
                if last_roundtrip is not None:
                    # We can only record the last total time since we also want to include the send time.
                    timing["prev_total_ms"] = last_roundtrip * 1000
                action["server_timing"] = timing

                await websocket.send(packer.pack(action))
                last_roundtrip = time.monotonic() - loop_started

            except websockets.ConnectionClosed:
                logger.info(f"Connection from {websocket.remote_address} closed")
                break
            except Exception:
                # Report the failure to the client before closing, then re-raise so
                # that the error is not silently lost on the server side.
                await websocket.send(traceback.format_exc())
                await websocket.close(
                    code=websockets.frames.CloseCode.INTERNAL_ERROR,
                    reason="Internal server error. Traceback included in previous frame.",
                )
                raise
return None ================================================ FILE: src/openpi/shared/__init__.py ================================================ ================================================ FILE: src/openpi/shared/array_typing.py ================================================ import contextlib import functools as ft import inspect from typing import TypeAlias, TypeVar, cast import beartype import jax import jax._src.tree_util as private_tree_util import jax.core from jaxtyping import ArrayLike from jaxtyping import Bool # noqa: F401 from jaxtyping import DTypeLike # noqa: F401 from jaxtyping import Float from jaxtyping import Int # noqa: F401 from jaxtyping import Key # noqa: F401 from jaxtyping import Num # noqa: F401 from jaxtyping import PyTree from jaxtyping import Real # noqa: F401 from jaxtyping import UInt8 # noqa: F401 from jaxtyping import config from jaxtyping import jaxtyped import jaxtyping._decorator import torch # patch jaxtyping to handle https://github.com/patrick-kidger/jaxtyping/issues/277. # the problem is that custom PyTree nodes are sometimes initialized with arbitrary types (e.g., `jax.ShapeDtypeStruct`, # `jax.Sharding`, or even ) due to JAX tracing operations. this patch skips typechecking when the stack trace # contains `jax._src.tree_util`, which should only be the case during tree unflattening. 
@contextlib.contextmanager
def disable_typechecking():
    """Context manager that turns jaxtyping checks off for the duration of the block.

    Sets the `jaxtyping_disable` config flag to True on entry and restores the value
    it had before on exit. The restore runs in a `finally` clause so that an
    exception raised inside the `with` body cannot leave typechecking disabled.
    """
    initial = config.jaxtyping_disable
    config.update("jaxtyping_disable", True)  # noqa: FBT003
    try:
        yield
    finally:
        config.update("jaxtyping_disable", initial)
""" if errors := list(private_tree_util.equality_errors(expected, got)): raise ValueError( "PyTrees have different structure:\n" + ( "\n".join( f" - at keypath '{jax.tree_util.keystr(path)}': expected {thing1}, got {thing2}, so {explanation}.\n" for path, thing1, thing2, explanation in errors ) ) ) if check_shapes or check_dtypes: def check(kp, x, y): if check_shapes and x.shape != y.shape: raise ValueError(f"Shape mismatch at {jax.tree_util.keystr(kp)}: expected {x.shape}, got {y.shape}") if check_dtypes and x.dtype != y.dtype: raise ValueError(f"Dtype mismatch at {jax.tree_util.keystr(kp)}: expected {x.dtype}, got {y.dtype}") jax.tree_util.tree_map_with_path(check, expected, got) ================================================ FILE: src/openpi/shared/download.py ================================================ import concurrent.futures import datetime import logging import os import pathlib import re import shutil import stat import subprocess import time import urllib.parse import filelock import fsspec import fsspec.generic import tqdm_loggable.auto as tqdm # Environment variable to control cache directory path, ~/.cache/openpi will be used by default. _OPENPI_DATA_HOME = "OPENPI_DATA_HOME" DEFAULT_CACHE_DIR = "~/.cache/openpi" logger = logging.getLogger(__name__) def get_cache_dir() -> pathlib.Path: cache_dir = pathlib.Path(os.getenv(_OPENPI_DATA_HOME, DEFAULT_CACHE_DIR)).expanduser().resolve() cache_dir.mkdir(parents=True, exist_ok=True) _set_folder_permission(cache_dir) return cache_dir def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path: """Download a file or directory from a remote filesystem to the local cache, and return the local path. If the local file already exists, it will be returned directly. It is safe to call this function concurrently from multiple processes. See `get_cache_dir` for more details on the cache directory. Args: url: URL to the file to download. 
force_download: If True, the file will be downloaded even if it already exists in the cache. **kwargs: Additional arguments to pass to fsspec. Returns: Local path to the downloaded file or directory. That path is guaranteed to exist and is absolute. """ # Don't use fsspec to parse the url to avoid unnecessary connection to the remote filesystem. parsed = urllib.parse.urlparse(url) # Short circuit if this is a local path. if parsed.scheme == "": path = pathlib.Path(url) if not path.exists(): raise FileNotFoundError(f"File not found at {url}") return path.resolve() cache_dir = get_cache_dir() local_path = cache_dir / parsed.netloc / parsed.path.strip("/") local_path = local_path.resolve() # Check if the cache should be invalidated. invalidate_cache = False if local_path.exists(): if force_download or _should_invalidate_cache(cache_dir, local_path): invalidate_cache = True else: return local_path try: lock_path = local_path.with_suffix(".lock") with filelock.FileLock(lock_path): # Ensure consistent permissions for the lock file. _ensure_permissions(lock_path) # First, remove the existing cache if it is expired. if invalidate_cache: logger.info(f"Removing expired cached entry: {local_path}") if local_path.is_dir(): shutil.rmtree(local_path) else: local_path.unlink() if not local_path.exists(): # Download the data to a local cache. logger.info(f"Downloading {url} to {local_path}") scratch_path = local_path.with_suffix(".partial") # Route openpi-assets through gsutil to avoid gcsfs auth issues with this bucket. # All other gs:// URLs (e.g. big_vision) continue to use gcsfs as normal. if parsed.scheme == "gs" and parsed.netloc == "openpi-assets": _download_gsutil(url, scratch_path, **kwargs) else: _download_fsspec(url, scratch_path, **kwargs) shutil.move(scratch_path, local_path) _ensure_permissions(local_path) except PermissionError as e: msg = ( f"Local file permission error was encountered while downloading {url}. 
" f"Please try again after removing the cached data using: `rm -rf {local_path}*`" ) raise PermissionError(msg) from e return local_path def _download_gsutil(url: str, local_path: pathlib.Path, **kwargs) -> None: """Download a file or directory from GCS using gsutil if available, otherwise fall back to gcsfs.""" if shutil.which("gsutil") is None: logger.warning( "gsutil not found, falling back to gcsfs. This may fail if GCP credentials are not configured correctly." ) _download_fsspec(url, local_path, **kwargs) return local_path.mkdir(parents=True, exist_ok=True) subprocess.run( ["gsutil", "-m", "cp", "-r", f"{url}/*", str(local_path)], check=True, ) def _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None: """Download a file from a remote filesystem to the local cache, and return the local path.""" fs, _ = fsspec.core.url_to_fs(url, **kwargs) info = fs.info(url) # Folders are represented by 0-byte objects with a trailing forward slash. if is_dir := (info["type"] == "directory" or (info["size"] == 0 and info["name"].endswith("/"))): total_size = fs.du(url) else: total_size = info["size"] with tqdm.tqdm(total=total_size, unit="iB", unit_scale=True, unit_divisor=1024) as pbar: executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) future = executor.submit(fs.get, url, local_path, recursive=is_dir) while not future.done(): current_size = sum(f.stat().st_size for f in [*local_path.rglob("*"), local_path] if f.is_file()) pbar.update(current_size - pbar.n) time.sleep(1) pbar.update(total_size - pbar.n) def _set_permission(path: pathlib.Path, target_permission: int): """chmod requires executable permission to be set, so we skip if the permission is already match with the target.""" if path.stat().st_mode & target_permission == target_permission: logger.debug(f"Skipping {path} because it already has correct permissions") return path.chmod(target_permission) logger.debug(f"Set {path} to {target_permission}") def _set_folder_permission(folder_path: 
pathlib.Path) -> None: """Set folder permission to be read, write and searchable.""" _set_permission(folder_path, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) def _ensure_permissions(path: pathlib.Path) -> None: """Since we are sharing cache directory with containerized runtime as well as training script, we need to ensure that the cache directory has the correct permissions. """ def _setup_folder_permission_between_cache_dir_and_path(path: pathlib.Path) -> None: cache_dir = get_cache_dir() relative_path = path.relative_to(cache_dir) moving_path = cache_dir for part in relative_path.parts: _set_folder_permission(moving_path / part) moving_path = moving_path / part def _set_file_permission(file_path: pathlib.Path) -> None: """Set all files to be read & writable, if it is a script, keep it as a script.""" file_rw = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH | stat.S_IWOTH if file_path.stat().st_mode & 0o100: _set_permission(file_path, file_rw | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) else: _set_permission(file_path, file_rw) _setup_folder_permission_between_cache_dir_and_path(path) for root, dirs, files in os.walk(str(path)): root_path = pathlib.Path(root) for file in files: file_path = root_path / file _set_file_permission(file_path) for dir in dirs: dir_path = root_path / dir _set_folder_permission(dir_path) def _get_mtime(year: int, month: int, day: int) -> float: """Get the mtime of a given date at midnight UTC.""" date = datetime.datetime(year, month, day, tzinfo=datetime.UTC) return time.mktime(date.timetuple()) # Map of relative paths, defined as regular expressions, to expiration timestamps (mtime format). # Partial matching will be used from top to bottom and the first match will be chosen. # Cached entries will be retained only if they are newer than the expiration timestamp. 
def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
    """Decide whether a cached entry is expired.

    The entry's path relative to `cache_dir` is matched against the patterns in
    `_INVALIDATE_CACHE_DIRS`, in their declared order; only the first matching
    pattern is considered. This function only reports the decision, it does not
    delete anything.

    Args:
        cache_dir: Root of the cache; `local_path` must be located below it.
        local_path: Existing cached file or directory.

    Returns:
        True if a pattern matches and the entry's mtime is not newer than that
        pattern's expiration timestamp, otherwise False.
    """
    assert local_path.exists(), f"File not found at {local_path}"

    relative = str(local_path.relative_to(cache_dir))
    first_deadline = next(
        (deadline for regex, deadline in _INVALIDATE_CACHE_DIRS.items() if regex.match(relative)),
        None,
    )
    if first_deadline is None:
        return False
    # Remove if not newer than the expiration timestamp.
    return local_path.stat().st_mtime <= first_deadline
test_download_fsspec(): remote_path = "gs://big_vision/paligemma_tokenizer.model" local_path = download.maybe_download(remote_path, gs={"token": "anon"}) assert local_path.exists() new_local_path = download.maybe_download(remote_path, gs={"token": "anon"}) assert new_local_path == local_path ================================================ FILE: src/openpi/shared/image_tools.py ================================================ import functools import jax import jax.numpy as jnp import torch import torch.nn.functional as F # noqa: N812 import openpi.shared.array_typing as at @functools.partial(jax.jit, static_argnums=(1, 2, 3)) @at.typecheck def resize_with_pad( images: at.UInt8[at.Array, "*b h w c"] | at.Float[at.Array, "*b h w c"], height: int, width: int, method: jax.image.ResizeMethod = jax.image.ResizeMethod.LINEAR, ) -> at.UInt8[at.Array, "*b {height} {width} c"] | at.Float[at.Array, "*b {height} {width} c"]: """Replicates tf.image.resize_with_pad. Resizes an image to a target height and width without distortion by padding with black. If the image is float32, it must be in the range [-1, 1]. 
def resize_with_pad_torch(
    images: torch.Tensor,
    height: int,
    width: int,
    mode: str = "bilinear",
) -> torch.Tensor:
    """PyTorch version of resize_with_pad.

    Resizes an image to a target height and width without distortion by padding with black.
    If the image is float32, it must be in the range [-1, 1].

    The layout is guessed from the last dimension: a size of at most 4 is treated as a channel
    axis (channels-last), anything else as channels-first.

    Args:
        images: Tensor of shape [h, w, c], [b, h, w, c], [c, h, w] or [b, c, h, w].
        height: Target height.
        width: Target width.
        mode: Interpolation mode ('bilinear', 'nearest', etc.).

    Returns:
        Resized and padded tensor with the same rank and layout as the input. A batch axis is
        only removed from the result if this function added it to an unbatched input.

    Raises:
        ValueError: If the image dtype is neither uint8 nor float32.
    """
    channels_last = images.shape[-1] <= 4

    # Remember whether the batch axis is ours, so that a caller-provided batch of size 1
    # keeps its batch dimension in the output.
    added_batch_dim = images.dim() == 3
    if added_batch_dim:
        images = images.unsqueeze(0)

    if channels_last:
        # Torch resize/pad ops expect channels-first: [b, h, w, c] -> [b, c, h, w].
        images = images.permute(0, 3, 1, 2)

    _, _, cur_height, cur_width = images.shape

    # Scale by the larger ratio so that the resized image fits inside the target on both axes.
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)

    resized_images = F.interpolate(
        images, size=(resized_height, resized_width), mode=mode, align_corners=False if mode == "bilinear" else None
    )

    # Bring the values back into the valid range of the input dtype.
    if images.dtype == torch.uint8:
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    elif images.dtype == torch.float32:
        resized_images = resized_images.clamp(-1.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    # Split the padding evenly; an odd remainder goes to the bottom / right side.
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w

    # Black is 0 for uint8 images and -1 for float images in [-1, 1].
    constant_value = 0 if images.dtype == torch.uint8 else -1.0
    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
        mode="constant",
        value=constant_value,
    )

    if channels_last:
        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]

    if added_batch_dim:
        padded_images = padded_images.squeeze(0)

    return padded_images
def module_jit(meth: Callable[P, R], *jit_args, **jit_kwargs) -> Callable[P, R]:
    """JIT-compile a bound `nnx.Module` method against a frozen snapshot of the module's state.

    Why not `nnx.jit`? Naively applying `nnx.jit` to `nnx.Module` methods, bound or unbound, uses much more
    memory than necessary, presumably because it has to keep track of module mutations. `nnx.jit` also has some
    inherent overhead compared to a standard `jax.jit`, since every call must traverse the NNX module graph. See
    https://github.com/google/flax/discussions/4224 for details.

    `module_jit` avoids both issues. The module is split once, when `module_jit` is called, and the resulting
    state is passed to a plain `jax.jit`-compiled function on every call. The returned function behaves like the
    original method, except that it always sees the state captured at wrap time. Mutations made to the module
    inside `meth` are allowed but are discarded once the call returns.

    Extra positional and keyword arguments are forwarded to `jax.jit`. Note that the compiled function takes the
    module state as its first argument, ahead of the arguments of `meth`.

    Raises:
        ValueError: If `meth` is not a bound method of an `nnx.Module`.
    """
    is_bound_module_method = inspect.ismethod(meth) and isinstance(meth.__self__, nnx.Module)
    if not is_bound_module_method:
        raise ValueError("module_jit must only be used on bound methods of nnx.Modules.")

    # Snapshot the module now; later changes to the live module are not seen by the wrapper.
    graphdef, frozen_state = nnx.split(meth.__self__)
    unbound_method = meth.__func__

    def call_with_state(state: nnx.State, *args: P.args, **kwargs: P.kwargs) -> R:
        # Rebuild a throwaway module from the state and run the method on it.
        return unbound_method(nnx.merge(graphdef, state), *args, **kwargs)

    compiled = jax.jit(call_with_state, *jit_args, **jit_kwargs)

    @functools.wraps(meth)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        return compiled(frozen_state, *args, **kwargs)

    return wrapper
""" pattern: str | re.Pattern sep: str = "/" def __post_init__(self): if not isinstance(self.pattern, re.Pattern): object.__setattr__(self, "pattern", re.compile(self.pattern)) def __call__(self, path: nnx.filterlib.PathParts, x: Any) -> bool: joined_path = self.sep.join(str(x) for x in path) assert isinstance(self.pattern, re.Pattern) return self.pattern.fullmatch(joined_path) is not None def state_map(state: nnx.State, filter: nnx.filterlib.Filter, fn: Callable[[Any], Any]) -> nnx.State: """Apply a function to the leaves of the state that match the filter.""" filtered_keys = set(state.filter(filter).flat_state()) return state.map(lambda k, v: fn(v) if k in filtered_keys else v) ================================================ FILE: src/openpi/shared/normalize.py ================================================ import json import pathlib import numpy as np import numpydantic import pydantic @pydantic.dataclasses.dataclass class NormStats: mean: numpydantic.NDArray std: numpydantic.NDArray q01: numpydantic.NDArray | None = None # 1st quantile q99: numpydantic.NDArray | None = None # 99th quantile class RunningStats: """Compute running statistics of a batch of vectors.""" def __init__(self): self._count = 0 self._mean = None self._mean_of_squares = None self._min = None self._max = None self._histograms = None self._bin_edges = None self._num_quantile_bins = 5000 # for computing quantiles on the fly def update(self, batch: np.ndarray) -> None: """ Update the running statistics with a batch of vectors. Args: vectors (np.ndarray): An array where all dimensions except the last are batch dimensions. 
""" batch = batch.reshape(-1, batch.shape[-1]) num_elements, vector_length = batch.shape if self._count == 0: self._mean = np.mean(batch, axis=0) self._mean_of_squares = np.mean(batch**2, axis=0) self._min = np.min(batch, axis=0) self._max = np.max(batch, axis=0) self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)] self._bin_edges = [ np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1) for i in range(vector_length) ] else: if vector_length != self._mean.size: raise ValueError("The length of new vectors does not match the initialized vector length.") new_max = np.max(batch, axis=0) new_min = np.min(batch, axis=0) max_changed = np.any(new_max > self._max) min_changed = np.any(new_min < self._min) self._max = np.maximum(self._max, new_max) self._min = np.minimum(self._min, new_min) if max_changed or min_changed: self._adjust_histograms() self._count += num_elements batch_mean = np.mean(batch, axis=0) batch_mean_of_squares = np.mean(batch**2, axis=0) # Update running mean and mean of squares. self._mean += (batch_mean - self._mean) * (num_elements / self._count) self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (num_elements / self._count) self._update_histograms(batch) def get_statistics(self) -> NormStats: """ Compute and return the statistics of the vectors processed so far. Returns: dict: A dictionary containing the computed statistics. 
""" if self._count < 2: raise ValueError("Cannot compute statistics for less than 2 vectors.") variance = self._mean_of_squares - self._mean**2 stddev = np.sqrt(np.maximum(0, variance)) q01, q99 = self._compute_quantiles([0.01, 0.99]) return NormStats(mean=self._mean, std=stddev, q01=q01, q99=q99) def _adjust_histograms(self): """Adjust histograms when min or max changes.""" for i in range(len(self._histograms)): old_edges = self._bin_edges[i] new_edges = np.linspace(self._min[i], self._max[i], self._num_quantile_bins + 1) # Redistribute the existing histogram counts to the new bins new_hist, _ = np.histogram(old_edges[:-1], bins=new_edges, weights=self._histograms[i]) self._histograms[i] = new_hist self._bin_edges[i] = new_edges def _update_histograms(self, batch: np.ndarray) -> None: """Update histograms with new vectors.""" for i in range(batch.shape[1]): hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i]) self._histograms[i] += hist def _compute_quantiles(self, quantiles): """Compute quantiles based on histograms.""" results = [] for q in quantiles: target_count = q * self._count q_values = [] for hist, edges in zip(self._histograms, self._bin_edges, strict=True): cumsum = np.cumsum(hist) idx = np.searchsorted(cumsum, target_count) q_values.append(edges[idx]) results.append(np.array(q_values)) return results class _NormStatsDict(pydantic.BaseModel): norm_stats: dict[str, NormStats] def serialize_json(norm_stats: dict[str, NormStats]) -> str: """Serialize the running statistics to a JSON string.""" return _NormStatsDict(norm_stats=norm_stats).model_dump_json(indent=2) def deserialize_json(data: str) -> dict[str, NormStats]: """Deserialize the running statistics from a JSON string.""" return _NormStatsDict(**json.loads(data)).norm_stats def save(directory: pathlib.Path | str, norm_stats: dict[str, NormStats]) -> None: """Save the normalization stats to a directory.""" path = pathlib.Path(directory) / "norm_stats.json" path.parent.mkdir(parents=True, 
def load(directory: pathlib.Path | str) -> dict[str, NormStats]:
    """Read the normalization stats stored as `norm_stats.json` inside `directory`.

    Raises:
        FileNotFoundError: If the directory does not contain a `norm_stats.json` file.
    """
    stats_file = pathlib.Path(directory) / "norm_stats.json"
    if stats_file.exists():
        return deserialize_json(stats_file.read_text())
    raise FileNotFoundError(f"Norm stats file not found at: {stats_file}")
def initialize_checkpoint_dir(
    checkpoint_dir: epath.Path | str, *, keep_period: int | None, overwrite: bool, resume: bool
) -> tuple[ocp.CheckpointManager, bool]:
    """Prepare the checkpoint directory and create a checkpoint manager for it.

    An existing directory is wiped if `overwrite` is set and kept for resuming if `resume` is
    set, with `overwrite` taking precedence. The directory is created if it does not exist.

    Returns:
        The checkpoint manager and a flag telling whether training should resume from an
        existing checkpoint.

    Raises:
        FileExistsError: If the directory exists and neither `overwrite` nor `resume` is set.
    """
    checkpoint_dir = epath.Path(checkpoint_dir).resolve()
    dir_exists = checkpoint_dir.exists()

    if dir_exists and overwrite:
        checkpoint_dir.rmtree()
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"Wiped checkpoint directory {checkpoint_dir}")
    elif dir_exists and not resume:
        raise FileExistsError(
            f"Checkpoint directory {checkpoint_dir} already exists. Use --overwrite or --resume "
            "to indicate how to handle it."
        )
    resuming = bool(dir_exists and resume and not overwrite)

    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    manager_options = ocp.CheckpointManagerOptions(
        max_to_keep=1,
        keep_period=keep_period,
        create=False,
        async_options=ocp.AsyncOptions(timeout_secs=7200),
    )
    item_handlers = {
        "assets": CallbackHandler(),
        "train_state": ocp.PyTreeCheckpointHandler(),
        "params": ocp.PyTreeCheckpointHandler(),
    }
    mngr = ocp.CheckpointManager(checkpoint_dir, item_handlers=item_handlers, options=manager_options)

    # Special case: the directory exists and a resume was requested, but the previous run never
    # got past its first checkpoint. Restoring would fail, so fall back to a fresh start.
    if resuming:
        saved_steps = tuple(mngr.all_steps())
        if saved_steps in ((), (0,)):
            logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.")
            resuming = False

    return mngr, resuming
def _merge_params(train_state: training_utils.TrainState, params: dict[str, at.Params]) -> training_utils.TrainState:
    """Put separately stored params back into a train state, undoing `_split_params`.

    If the train state still carries `params`, the split must have taken out the EMA params, so
    the restored values go back into `ema_params`. Otherwise they are the regular `params`.
    """
    restored = params["params"]
    target_field = "ema_params" if train_state.params else "params"
    return dataclasses.replace(train_state, **{target_field: restored})
For example, to load the norm stats for the Trossen robot from the base model checkpoint during fine-tuning, use: ``` AssetsConfig( assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets", asset_id="trossen", ) ``` """ # Assets directory. If not provided, the config assets_dirs will be used. This is useful to load assets from # a different checkpoint (e.g., base model checkpoint) or some other centralized location. assets_dir: str | None = None # Asset id. If not provided, the repo id will be used. This allows users to reference assets that describe # different robot platforms. asset_id: str | None = None @dataclasses.dataclass(frozen=True) class DataConfig: # LeRobot repo id. If None, fake data will be created. repo_id: str | None = None # Directory within the assets directory containing the data assets. asset_id: str | None = None # Contains precomputed normalization stats. If None, normalization will not be performed. norm_stats: dict[str, _transforms.NormStats] | None = None # Used to adopt the inputs from a dataset specific format to a common format # which is expected by the data transforms. repack_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group) # Data transforms, typically include robot specific transformations. Will be applied # before the data is normalized. See `model.Observation` and `model.Actions` to learn about the # normalized data. data_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group) # Model specific transforms. Will be applied after the data is normalized. model_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group) # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used. use_quantile_norm: bool = False # Names of keys that will be used by the data loader to generate the action sequence. The length of the # sequence is defined by the `action_horizon` field in the model config. 
This should be adjusted if your # LeRobot dataset is using different keys to represent the action. action_sequence_keys: Sequence[str] = ("actions",) # If true, will use the LeRobot dataset task to define the prompt. prompt_from_task: bool = False # Only used for RLDS data loader (ie currently only used for DROID). rlds_data_dir: str | None = None # Action space for DROID dataset. action_space: droid_rlds_dataset.DroidActionSpace | None = None # List of datasets to sample from: name, version, weight, and optionally filter_dict_path datasets: Sequence[droid_rlds_dataset.RLDSDataset] = () class GroupFactory(Protocol): def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group: """Create a group.""" @dataclasses.dataclass(frozen=True) class ModelTransformFactory(GroupFactory): """Creates model transforms for standard pi0 models.""" # If provided, will determine the default prompt that be used by the model. default_prompt: str | None = None def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group: match model_config.model_type: case _model.ModelType.PI0: return _transforms.Group( inputs=[ _transforms.InjectDefaultPrompt(self.default_prompt), _transforms.ResizeImages(224, 224), _transforms.TokenizePrompt( _tokenizer.PaligemmaTokenizer(model_config.max_token_len), ), _transforms.PadStatesAndActions(model_config.action_dim), ], ) case _model.ModelType.PI05: assert isinstance(model_config, pi0_config.Pi0Config) return _transforms.Group( inputs=[ _transforms.InjectDefaultPrompt(self.default_prompt), _transforms.ResizeImages(224, 224), _transforms.TokenizePrompt( _tokenizer.PaligemmaTokenizer(model_config.max_token_len), discrete_state_input=model_config.discrete_state_input, ), _transforms.PadStatesAndActions(model_config.action_dim), ], ) case _model.ModelType.PI0_FAST: tokenizer_cls = ( _tokenizer.FASTTokenizer if model_config.fast_model_tokenizer is None else model_config.fast_model_tokenizer ) tokenizer_kwargs = ( {} if 
model_config.fast_model_tokenizer_kwargs is None else model_config.fast_model_tokenizer_kwargs ) return _transforms.Group( inputs=[ _transforms.InjectDefaultPrompt(self.default_prompt), _transforms.ResizeImages(224, 224), _transforms.TokenizeFASTInputs( tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs), ), ], outputs=[ _transforms.ExtractFASTActions( tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs), action_horizon=model_config.action_horizon, action_dim=model_config.action_dim, ) ], ) @dataclasses.dataclass(frozen=True) class DataConfigFactory(abc.ABC): # The LeRobot repo id. repo_id: str = tyro.MISSING # Determines how the assets will be loaded. assets: AssetsConfig = dataclasses.field(default_factory=AssetsConfig) # Base config that will be updated by the factory. base_config: tyro.conf.Suppress[DataConfig | None] = None @abc.abstractmethod def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: """Create a data config.""" def create_base_config(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig: repo_id = self.repo_id if self.repo_id is not tyro.MISSING else None asset_id = self.assets.asset_id or repo_id return dataclasses.replace( self.base_config or DataConfig(), repo_id=repo_id, asset_id=asset_id, norm_stats=self._load_norm_stats(epath.Path(self.assets.assets_dir or assets_dirs), asset_id), use_quantile_norm=model_config.model_type != ModelType.PI0, ) def _load_norm_stats(self, assets_dir: epath.Path, asset_id: str | None) -> dict[str, _transforms.NormStats] | None: if asset_id is None: return None try: data_assets_dir = str(assets_dir / asset_id) norm_stats = _normalize.load(_download.maybe_download(data_assets_dir)) logging.info(f"Loaded norm stats from {data_assets_dir}") return norm_stats except FileNotFoundError: logging.info(f"Norm stats not found in {data_assets_dir}, skipping.") return None @dataclasses.dataclass(frozen=True) class 
@dataclasses.dataclass(frozen=True)
class SimpleDataConfig(DataConfigFactory):
    # Data config factory that takes the base config and fills in the data and model transforms
    # produced by two pluggable group factories.

    # Factory for the data transforms.
    # NOTE(review): the default factory is the `GroupFactory` protocol class itself. Protocol
    # classes cannot be instantiated, so relying on this default presumably fails with a
    # TypeError -- confirm that this field is meant to be effectively required.
    data_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=GroupFactory)
    # Factory for the model transforms.
    model_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=ModelTransformFactory)

    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # Both factories are called with the model config to build the transform groups.
        return dataclasses.replace(
            self.create_base_config(assets_dirs, model_config),
            data_transforms=self.data_transforms(model_config),
            model_transforms=self.model_transforms(model_config),
        )
@dataclasses.dataclass(frozen=True)
class LeRobotLiberoDataConfig(DataConfigFactory):
    """
    Configures the transforms that are applied at the various stages of the data pipeline for LIBERO.

    For your own dataset, you can copy this class and modify the transforms to match your dataset
    based on the comments below.
    """

    # LIBERO already represents actions as deltas, but some old Pi0 checkpoints were trained with
    # an extra delta transform on top. Enable this flag to reproduce that setup.
    extra_delta_transform: bool = False

    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # The repack transform is *only* applied to data coming from the dataset, *not* during
        # inference. It makes dataset samples look like the inputs of the inference environment by
        # matching the keys of the dataset (defined in the data conversion script) to the keys used
        # by the inference pipeline (defined in the inference script for libero).
        # For your own dataset, first figure out which keys your environment passes to the policy
        # server and then adjust the mapping below. Here it simply remaps key names.
        dataset_key_map = {
            "observation/image": "image",
            "observation/wrist_image": "wrist_image",
            "observation/state": "state",
            "actions": "actions",
            "prompt": "prompt",
        }
        repack = _transforms.Group(inputs=[_transforms.RepackTransform(dataset_key_map)])

        # The data transforms are applied to dataset samples *and* during inference. `inputs` are
        # the transforms for data going into the model, `outputs` the ones for data coming out of
        # the model (only used during inference). They are defined in `libero_policy.py`; see the
        # comments there for how to adapt them. Once you created your own transforms, use them here.
        robot_transforms = _transforms.Group(
            inputs=[libero_policy.LiberoInputs(model_type=model_config.model_type)],
            outputs=[libero_policy.LiberoOutputs()],
        )

        # pi0 models are trained on delta actions (relative to the first state in each action
        # chunk). If your data has ``absolute`` actions (e.g. target joint angles), a delta
        # conversion has to be applied. The gripper actions are the exception, they always stay
        # absolute. The mask below converts the first 6 actions (joints) and leaves the 7th action
        # (gripper) unchanged.
        # The raw LIBERO actions are already deltas, so the conversion is disabled by default and
        # only applied when `extra_delta_transform` is set. Choose based on whether your dataset
        # uses ``absolute`` or ``delta`` actions out of the box.
        if self.extra_delta_transform:
            joint_mask = _transforms.make_bool_mask(6, -1)
            robot_transforms = robot_transforms.push(
                inputs=[_transforms.DeltaActions(joint_mask)],
                outputs=[_transforms.AbsoluteActions(joint_mask)],
            )

        # Model transforms include things like tokenizing the prompt and action targets.
        # You do not need to change anything here for your own dataset.
        model_specific = ModelTransformFactory()(model_config)

        # Return all transforms for training and inference. No need to change anything here.
        base_config = self.create_base_config(assets_dirs, model_config)
        return dataclasses.replace(
            base_config,
            repack_transforms=repack,
            data_transforms=robot_transforms,
            model_transforms=model_specific,
        )
@dataclasses.dataclass(frozen=True)
class RLDSDroidDataConfig(DataConfigFactory):
    """
    Config for training on DROID, using RLDS data format (for efficient training on larger datasets).
    """

    # Directory holding the RLDS data. Must be set before `create` is called.
    rlds_data_dir: str | None = None
    # Action space of the loaded actions. JOINT_POSITION triggers a delta conversion in `create`.
    action_space: droid_rlds_dataset.DroidActionSpace | None = None

    # Filtering options. A dataset entry can carry a path (``filter_dict_path``) to a dictionary
    # that maps episodes to the ranges of time steps to keep, given as (start, end) tuples.
    # Episodes are uniquely identified with f"{recording_folderpath}--{file_path}", both of which
    # are present in the RLDS episode metadata.

    # List of datasets to sample from: name, version, weight, and optionally filter_dict_path
    datasets: Sequence[droid_rlds_dataset.RLDSDataset] = (
        droid_rlds_dataset.RLDSDataset(
            name="droid",
            version="1.0.1",
            weight=1.0,
            filter_dict_path="gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json",
        ),
    )

    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        """Build the `DataConfig` for RLDS-based DROID training.

        Raises:
            ValueError: If `rlds_data_dir` has not been set.
        """
        # Validate up front with a real exception. An `assert` is stripped under `python -O`,
        # which would let a config with `rlds_data_dir=None` through unnoticed.
        if self.rlds_data_dir is None:
            raise ValueError("Need to set rlds data dir for RLDS data loader.")

        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "observation/exterior_image_1_left": "observation/image",
                        "observation/wrist_image_left": "observation/wrist_image",
                        "observation/joint_position": "observation/joint_position",
                        "observation/gripper_position": "observation/gripper_position",
                        "actions": "actions",
                        "prompt": "prompt",
                    }
                )
            ]
        )

        data_transforms = _transforms.Group(
            inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)],
            outputs=[droid_policy.DroidOutputs()],
        )

        if self.action_space == droid_rlds_dataset.DroidActionSpace.JOINT_POSITION:
            # Data loader returns absolute joint position actions -- convert to delta actions for training.
            delta_action_mask = _transforms.make_bool_mask(7, -1)
            data_transforms = data_transforms.push(
                inputs=[_transforms.DeltaActions(delta_action_mask)],
                outputs=[_transforms.AbsoluteActions(delta_action_mask)],
            )

        model_transforms = ModelTransformFactory()(model_config)

        return dataclasses.replace(
            self.create_base_config(assets_dirs, model_config),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
            rlds_data_dir=self.rlds_data_dir,
            action_space=self.action_space,
            datasets=self.datasets,
        )
@dataclasses.dataclass(frozen=True)
class LeRobotDROIDDataConfig(DataConfigFactory):
    """Example data config for a custom DROID dataset in LeRobot format.

    Intended for smaller custom DROID datasets (below tens of hours). The conversion script is
    examples/droid/convert_droid_data_to_lerobot.py.
    """

    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        """Build the `DataConfig` for a LeRobot-format DROID dataset."""
        key_mapping = {
            "observation/exterior_image_1_left": "exterior_image_1_left",
            "observation/exterior_image_2_left": "exterior_image_2_left",
            "observation/wrist_image_left": "wrist_image_left",
            "observation/joint_position": "joint_position",
            "observation/gripper_position": "gripper_position",
            "actions": "actions",
            "prompt": "prompt",
        }
        repack_transform = _transforms.Group(inputs=[_transforms.RepackTransform(key_mapping)])

        # Actions are assumed to be joint *velocities*, so no additional delta transform is
        # pushed onto the data transforms here.
        droid_inputs = droid_policy.DroidInputs(model_type=model_config.model_type)
        droid_outputs = droid_policy.DroidOutputs()
        data_transforms = _transforms.Group(inputs=[droid_inputs], outputs=[droid_outputs])

        model_transforms = ModelTransformFactory()(model_config)

        base_config = self.create_base_config(assets_dirs, model_config)
        return dataclasses.replace(
            base_config,
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
        )
weight_loader: weight_loaders.WeightLoader = dataclasses.field(default_factory=weight_loaders.NoOpWeightLoader) # Optional path to a PyTorch checkpoint to load weights from. pytorch_weight_path: str | None = None # Precision for PyTorch training. pytorch_training_precision: Literal["bfloat16", "float32"] = "bfloat16" lr_schedule: _optimizer.LRScheduleConfig = dataclasses.field(default_factory=_optimizer.CosineDecaySchedule) optimizer: _optimizer.OptimizerConfig = dataclasses.field(default_factory=_optimizer.AdamW) ema_decay: float | None = 0.99 # Specifies which weights should be frozen. freeze_filter: tyro.conf.Suppress[Filter] = dataclasses.field(default_factory=nnx.Nothing) # Determines the data to be trained on. data: DataConfigFactory = dataclasses.field(default_factory=FakeDataConfig) # Base directory for config assets (e.g., norm stats). assets_base_dir: str = "./assets" # Base directory for checkpoints. checkpoint_base_dir: str = "./checkpoints" # Random seed that will be used by random generators during training. seed: int = 42 # Global batch size. batch_size: int = 32 # Number of workers to use for the data loader. Increasing this number will speed up data loading but # will increase memory and CPU usage. num_workers: int = 2 # Number of train steps (batches) to run. num_train_steps: int = 30_000 # How often (in steps) to log training metrics. log_interval: int = 100 # How often (in steps) to save checkpoints. save_interval: int = 1000 # If set, any existing checkpoints matching step % keep_period == 0 will not be deleted. keep_period: int | None = 5000 # If true, will overwrite the checkpoint directory if it already exists. overwrite: bool = False # If true, will resume training from the last checkpoint. resume: bool = False # If true, will enable wandb logging. wandb_enabled: bool = True # Used to pass metadata to the policy server. 
policy_metadata: dict[str, Any] | None = None # If the value is greater than 1, FSDP will be enabled and shard across number of specified devices; overall # device memory will be reduced but training could potentially be slower. # eg. if total device is 4 and fsdp devices is 2; then the model will shard to 2 devices and run # data parallel between 2 groups of devices. fsdp_devices: int = 1 @property def assets_dirs(self) -> pathlib.Path: """Get the assets directory for this config.""" return (pathlib.Path(self.assets_base_dir) / self.name).resolve() @property def checkpoint_dir(self) -> pathlib.Path: """Get the checkpoint directory for this config.""" if not self.exp_name: raise ValueError("--exp_name must be set") return (pathlib.Path(self.checkpoint_base_dir) / self.name / self.exp_name).resolve() @property def trainable_filter(self) -> nnx.filterlib.Filter: """Get the filter for the trainable parameters.""" return nnx.All(nnx.Param, nnx.Not(self.freeze_filter)) def __post_init__(self) -> None: if self.resume and self.overwrite: raise ValueError("Cannot resume and overwrite at the same time.") # Use `get_config` if you need to get a config by name in your code. _CONFIGS = [ # # Inference Aloha configs. 
# TrainConfig( name="pi0_aloha", model=pi0_config.Pi0Config(), data=LeRobotAlohaDataConfig( assets=AssetsConfig(asset_id="trossen"), ), policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]}, ), TrainConfig( name="pi05_aloha", model=pi0_config.Pi0Config(pi05=True), data=LeRobotAlohaDataConfig( assets=AssetsConfig(asset_id="trossen"), ), policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]}, ), TrainConfig( name="pi0_aloha_towel", model=pi0_config.Pi0Config(), data=LeRobotAlohaDataConfig( assets=AssetsConfig(asset_id="trossen"), default_prompt="fold the towel", ), policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]}, ), TrainConfig( name="pi0_aloha_tupperware", model=pi0_config.Pi0Config(), data=LeRobotAlohaDataConfig( assets=AssetsConfig(asset_id="trossen"), default_prompt="open the tupperware and put the food on the plate", ), policy_metadata={"reset_pose": [0, -1.5, 1.5, 0, 0, 0]}, ), # # Inference DROID configs. # TrainConfig( name="pi0_droid", model=pi0_config.Pi0Config(action_horizon=10), data=SimpleDataConfig( assets=AssetsConfig(asset_id="droid"), data_transforms=lambda model: _transforms.Group( inputs=[droid_policy.DroidInputs(model_type=ModelType.PI0)], outputs=[droid_policy.DroidOutputs()], ), base_config=DataConfig( prompt_from_task=True, ), ), ), TrainConfig( name="pi0_fast_droid", model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=10), data=SimpleDataConfig( assets=AssetsConfig(asset_id="droid"), data_transforms=lambda model: _transforms.Group( inputs=[droid_policy.DroidInputs(model_type=ModelType.PI0_FAST)], outputs=[droid_policy.DroidOutputs()], ), base_config=DataConfig( prompt_from_task=True, ), ), ), TrainConfig( name="pi05_droid", model=pi0_config.Pi0Config(action_horizon=15, pi05=True), data=SimpleDataConfig( assets=AssetsConfig(asset_id="droid"), data_transforms=lambda model: _transforms.Group( inputs=[droid_policy.DroidInputs(model_type=ModelType.PI05)], outputs=[droid_policy.DroidOutputs()], ), base_config=DataConfig( 
prompt_from_task=True, ), ), ), # # Fine-tuning Libero configs. # # These train configs define the hyperparameters for fine-tuning the base model on your own dataset. # They are used to define key elements like the dataset you are training on, the base checkpoint you # are using, and other hyperparameters like how many training steps to run or what learning rate to use. # For your own dataset, you can copy this class and modify the dataset name, and data transforms based on # the comments below. TrainConfig( # Change the name to reflect your model and dataset. name="pi0_libero", # Here you define the model config -- In this example we use pi0 as the model # architecture and perform *full* finetuning. in the examples below we show how to modify # this to perform *low-memory* (LORA) finetuning and use pi0-FAST as an alternative architecture. model=pi0_config.Pi0Config(), # Here you define the dataset you are training on. In this example we use the Libero # dataset. For your own dataset, you can change the repo_id to point to your dataset. # Also modify the DataConfig to use the new config you made for your dataset above. data=LeRobotLiberoDataConfig( repo_id="physical-intelligence/libero", base_config=DataConfig( # This flag determines whether we load the prompt (i.e. the task instruction) from the # ``task`` field in the LeRobot dataset. If set to True, the prompt will show up in # a field called ``prompt`` in the input dict. The recommended setting is True. prompt_from_task=True, ), extra_delta_transform=True, ), # Here you define which pre-trained checkpoint you want to load to initialize the model. # This should match the model config you chose above -- i.e. in this case we use the pi0 base model. weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"), # Below you can define other hyperparameters like the learning rate, number of training steps, etc. 
# Check the base TrainConfig class for a full list of available hyperparameters. num_train_steps=30_000, ), TrainConfig( name="pi0_libero_low_mem_finetune", # Here is an example of loading a pi0 model for LoRA fine-tuning. model=pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"), data=LeRobotLiberoDataConfig( repo_id="physical-intelligence/libero", base_config=DataConfig(prompt_from_task=True), extra_delta_transform=True, ), weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"), num_train_steps=30_000, # The freeze filter defines which parameters should be frozen during training. # We have a convenience function in the model config that returns the default freeze filter # for the given model config for LoRA finetuning. Just make sure it matches the model config # you chose above. freeze_filter=pi0_config.Pi0Config( paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora" ).get_freeze_filter(), # Turn off EMA for LoRA finetuning. ema_decay=None, ), TrainConfig( name="pi0_fast_libero", # Here is an example of loading a pi0-FAST model for full finetuning. # Modify action_dim and action_horizon to match your dataset (action horizon is equal to # the desired action chunk length). # The max_token_len is the maximum number of (non-image) tokens the model can handle. # This includes the tokenized prompt, proprioceptive state, and (FAST-tokenized) action tokens. # Choosing this value too small may chop off tokens at the end of your sequence (the code will throw # a warning), while choosing it too large will waste memory (since we pad each batch element to the # max_token_len). A good rule of thumb is to use approx 180 for single-arm robots, and approx 250 for # two-arm robots. Generally, err on the lower side here first, and potentially increase the value if # you see many warnings being thrown during training. 
model=pi0_fast.Pi0FASTConfig(action_dim=7, action_horizon=10, max_token_len=180), data=LeRobotLiberoDataConfig( repo_id="physical-intelligence/libero", base_config=DataConfig(prompt_from_task=True), extra_delta_transform=True, ), # Note that we load the pi0-FAST base model checkpoint here. weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"), num_train_steps=30_000, ), TrainConfig( name="pi0_fast_libero_low_mem_finetune", # Here is an example of loading a pi0-FAST model for LoRA finetuning. # For setting action_dim, action_horizon, and max_token_len, see the comments above. model=pi0_fast.Pi0FASTConfig( action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora" ), data=LeRobotLiberoDataConfig( repo_id="physical-intelligence/libero", base_config=DataConfig(prompt_from_task=True), extra_delta_transform=True, ), weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"), num_train_steps=30_000, # Again, make sure to match the model config above when extracting the freeze filter # that specifies which parameters should be frozen during LoRA finetuning. freeze_filter=pi0_fast.Pi0FASTConfig( action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora" ).get_freeze_filter(), # Turn off EMA for LoRA finetuning. 
ema_decay=None, ), TrainConfig( name="pi05_libero", model=pi0_config.Pi0Config(pi05=True, action_horizon=10, discrete_state_input=False), data=LeRobotLiberoDataConfig( repo_id="physical-intelligence/libero", base_config=DataConfig(prompt_from_task=True), extra_delta_transform=False, ), batch_size=256, lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=10_000, peak_lr=5e-5, decay_steps=1_000_000, decay_lr=5e-5, ), optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), ema_decay=0.999, weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"), pytorch_weight_path="/path/to/your/pytorch_weight_path", num_train_steps=30_000, ), # # Fine-tuning Aloha configs. # # This is a test config that is used to illustate how train on a custom LeRobot dataset. # For instructions on how to convert and train on your own Aloha dataset see examples/aloha_real/README.md TrainConfig( name="pi0_aloha_pen_uncap", model=pi0_config.Pi0Config(), data=LeRobotAlohaDataConfig( repo_id="physical-intelligence/aloha_pen_uncap_diverse", assets=AssetsConfig( assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets", asset_id="trossen", ), default_prompt="uncap the pen", repack_transforms=_transforms.Group( inputs=[ _transforms.RepackTransform( { "images": { "cam_high": "observation.images.cam_high", "cam_left_wrist": "observation.images.cam_left_wrist", "cam_right_wrist": "observation.images.cam_right_wrist", }, "state": "observation.state", "actions": "action", } ) ] ), ), weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"), num_train_steps=20_000, ), TrainConfig( name="pi05_aloha_pen_uncap", model=pi0_config.Pi0Config(pi05=True), data=LeRobotAlohaDataConfig( repo_id="physical-intelligence/aloha_pen_uncap_diverse", assets=AssetsConfig( assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets", asset_id="trossen", ), default_prompt="uncap the pen", repack_transforms=_transforms.Group( inputs=[ 
_transforms.RepackTransform( { "images": { "cam_high": "observation.images.cam_high", "cam_left_wrist": "observation.images.cam_left_wrist", "cam_right_wrist": "observation.images.cam_right_wrist", }, "state": "observation.state", "actions": "action", } ) ] ), ), weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"), num_train_steps=20_000, batch_size=64, ), # # Fine-tuning DROID configs. # TrainConfig( # This config is for fine-tuning pi0-FAST-base on the *full* DROID dataset. # We use RLDS data loading to make training on this large dataset tractable. # For fine-tuning on your own DROID dataset, see below. name="pi0_fast_full_droid_finetune", model=pi0_fast.Pi0FASTConfig( action_dim=8, action_horizon=16, max_token_len=180, ), data=RLDSDroidDataConfig( repo_id="droid", # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory). rlds_data_dir="", action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, ), weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"), lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=1_000, peak_lr=5e-5, decay_steps=1_000_000, decay_lr=5e-5, ), num_train_steps=100_000, # 100k steps should be sufficient, takes ~2 days on 8x H100s batch_size=256, log_interval=100, save_interval=5000, keep_period=20_000, num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally ), TrainConfig( # This config is for fine-tuning pi05 on the *full* DROID dataset. # We use RLDS data loading to make training on this large dataset tractable. # For fine-tuning on your own DROID dataset, see below. name="pi05_full_droid_finetune", model=pi0_config.Pi0Config( pi05=True, action_dim=32, action_horizon=16, ), data=RLDSDroidDataConfig( repo_id="droid", # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory). 
rlds_data_dir="/mnt/pi-data/kevin", action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, assets=AssetsConfig( assets_dir="gs://openpi-assets/checkpoints/pi05_base/assets/", asset_id="droid", ), ), weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"), lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=1_000, peak_lr=5e-5, decay_steps=1_000_000, decay_lr=5e-5, ), num_train_steps=100_000, batch_size=256, log_interval=100, save_interval=5000, keep_period=10_000, num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally ), TrainConfig( # This config is for fine-tuning pi05-DROID on a custom (smaller) DROID dataset. # Here, we use LeRobot data format (like for all other fine-tuning examples) # To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py name="pi05_droid_finetune", model=pi0_config.Pi0Config( pi05=True, action_dim=32, # pi05 is trained with 32-dim actions action_horizon=16, ), data=LeRobotDROIDDataConfig( # Replace with your custom DROID LeRobot dataset repo id. repo_id="your_hf_username/my_droid_dataset", base_config=DataConfig(prompt_from_task=True), assets=AssetsConfig( # Important: reuse the original DROID norm stats during fine-tuning! assets_dir="gs://openpi-assets/checkpoints/pi05_droid/assets", asset_id="droid", ), ), weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_droid/params"), num_train_steps=20_000, batch_size=32, ), # # ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment. 
def cli() -> TrainConfig:
    """Parse a `TrainConfig` from the command line.

    Every registered config is offered by name as a base config, which
    `tyro.extras.overridable_config_cli` lets the user select and override.
    """
    named_configs = {name: (name, config) for name, config in _CONFIGS_DICT.items()}
    return tyro.extras.overridable_config_cli(named_configs)


def get_config(config_name: str) -> TrainConfig:
    """Get a config by name.

    Raises:
        ValueError: If no config is registered under `config_name`. The message includes the
            closest registered name as a suggestion when one is available.
    """
    if config_name in _CONFIGS_DICT:
        return _CONFIGS_DICT[config_name]

    # cutoff=0.0 disables the similarity threshold, so the best match is returned however
    # dissimilar it is.
    suggestions = difflib.get_close_matches(config_name, _CONFIGS_DICT.keys(), n=1, cutoff=0.0)
    hint = f" Did you mean '{suggestions[0]}'? " if suggestions else ""
    raise ValueError(f"Config '{config_name}' not found.{hint}")
" if closest else "" raise ValueError(f"Config '{config_name}' not found.{closest_str}") return _CONFIGS_DICT[config_name] ================================================ FILE: src/openpi/training/data_loader.py ================================================ from collections.abc import Iterator, Sequence import logging import multiprocessing import os import typing from typing import Literal, Protocol, SupportsIndex, TypeVar import jax import jax.numpy as jnp import lerobot.common.datasets.lerobot_dataset as lerobot_dataset import numpy as np import torch import openpi.models.model as _model import openpi.training.config as _config from openpi.training.droid_rlds_dataset import DroidRldsDataset import openpi.transforms as _transforms T_co = TypeVar("T_co", covariant=True) class Dataset(Protocol[T_co]): """Interface for a dataset with random access.""" def __getitem__(self, index: SupportsIndex) -> T_co: raise NotImplementedError("Subclasses of Dataset should implement __getitem__.") def __len__(self) -> int: raise NotImplementedError("Subclasses of Dataset should implement __len__.") class IterableDataset(Protocol[T_co]): """Interface for an iterable dataset.""" def __iter__(self) -> Iterator[T_co]: raise NotImplementedError("Subclasses of IterableDataset should implement __iter__.") def __len__(self) -> int: raise NotImplementedError("Subclasses of Dataset should implement __len__.") class DataLoader(Protocol[T_co]): """Interface for a data loader.""" def data_config(self) -> _config.DataConfig: """Get the data config for this data loader.""" raise NotImplementedError("Subclasses of DataLoader should implement data_config.") def __iter__(self) -> Iterator[T_co]: raise NotImplementedError("Subclasses of DataLoader should implement __iter__.") class TransformedDataset(Dataset[T_co]): def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]): self._dataset = dataset self._transform = _transforms.compose(transforms) def 
class TransformedDataset(Dataset[T_co]):
    """Random-access dataset that applies a composed transform to each item of a wrapped dataset."""

    def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]):
        self._transform = _transforms.compose(transforms)
        self._dataset = dataset

    def __getitem__(self, index: SupportsIndex) -> T_co:
        raw_item = self._dataset[index]
        return self._transform(raw_item)

    def __len__(self) -> int:
        return len(self._dataset)


class IterableTransformedDataset(IterableDataset[T_co]):
    """Iterable dataset that applies a composed transform to each element of a wrapped dataset.

    With `is_batched=True` every element is treated as a batch: it is split along the leading
    axis, transformed sample by sample, and stacked back into a batch.
    """

    def __init__(
        self,
        dataset: IterableDataset,
        transforms: Sequence[_transforms.DataTransformFn],
        *,
        is_batched: bool = False,
    ):
        self._is_batched = is_batched
        self._transform = _transforms.compose(transforms)
        self._dataset = dataset

    @staticmethod
    def _take_row(batch, row: int):
        """Index every leaf of `batch` at position `row` along its leading axis."""
        return jax.tree.map(lambda leaf: leaf[row], batch)

    def __iter__(self):
        for element in self._dataset:
            if not self._is_batched:
                yield self._transform(element)
                continue

            # Transforms operate on individual samples, so a batch has to be taken apart,
            # transformed per sample, and re-assembled afterwards.
            # The batch size is read from the first top-level value of the element.
            num_rows = next(value.shape[0] for value in element.values())
            rows = [self._take_row(element, row) for row in range(num_rows)]
            transformed_rows = [self._transform(row) for row in rows]
            yield jax.tree.map(lambda *leaves: np.stack(leaves, axis=0), *transformed_rows)

    def __len__(self) -> int:
        return len(self._dataset)
class FakeDataset(Dataset):
    """Synthetic dataset that generates random samples matching a model's input spec.

    Each sample is derived from a PRNG key seeded with the requested index.
    """

    def __init__(self, model_config: _model.BaseModelConfig, num_samples: int):
        self._observation_spec, self._action_spec = model_config.inputs_spec()
        self._num_samples = num_samples

    def __getitem__(self, index: SupportsIndex) -> dict:
        key = jax.random.key(index.__index__())

        def sample_leaf(spec: jax.ShapeDtypeStruct):
            # Consume a fresh sub-key per leaf so that leaves do not share random values.
            nonlocal key
            key, leaf_key = jax.random.split(key)
            # Remove the batch dimension.
            leaf_shape = spec.shape[1:]
            dtype = spec.dtype
            if dtype == jnp.float32:
                return jax.random.uniform(leaf_key, shape=leaf_shape, minval=-1.0, maxval=1.0)
            if dtype == jnp.int32:
                return jax.random.randint(leaf_key, shape=leaf_shape, minval=0, maxval=2048)
            # Any other dtype is filled with zeros.
            return jnp.zeros(shape=leaf_shape, dtype=dtype)

        # Observation leaves are generated before action leaves; the order fixes the key sequence.
        observation = jax.tree.map(sample_leaf, self._observation_spec)
        action = jax.tree.map(sample_leaf, self._action_spec)

        return {
            **observation.to_dict(),
            "actions": action,
        }

    def __len__(self) -> int:
        return self._num_samples


def create_torch_dataset(
    data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig
) -> Dataset:
    """Create a dataset for training.

    Returns a `FakeDataset` when the repo id is "fake", otherwise a LeRobot dataset whose action
    sequence keys are loaded as chunks of `action_horizon` consecutive frames. When
    `data_config.prompt_from_task` is set, the dataset is wrapped so that the prompt is derived
    from the LeRobot task.

    Raises:
        ValueError: If `data_config.repo_id` is None.
    """
    repo_id = data_config.repo_id
    if repo_id is None:
        raise ValueError("Repo ID is not set. Cannot create dataset.")
    if repo_id == "fake":
        return FakeDataset(model_config, num_samples=1024)

    dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)

    # Time offsets (in seconds) of the frames that make up one action chunk.
    delta_timestamps = {}
    for key in data_config.action_sequence_keys:
        delta_timestamps[key] = [step / dataset_meta.fps for step in range(action_horizon)]

    dataset = lerobot_dataset.LeRobotDataset(data_config.repo_id, delta_timestamps=delta_timestamps)

    if not data_config.prompt_from_task:
        return dataset
    return TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)])


def create_rlds_dataset(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    *,
    shuffle: bool = False,
) -> Dataset:
    """Create the RLDS dataset described by `data_config`.

    Batching is delegated to the dataset itself, which receives `batch_size` directly.
    """
    # NOTE(review): the result is passed to `transform_iterable_dataset` by the RLDS data loader,
    # so the `Dataset` return annotation looks too narrow -- confirm against DroidRldsDataset.
    # At the moment, we only support DROID for RLDS datasets.
    return DroidRldsDataset(
        data_dir=data_config.rlds_data_dir,
        batch_size=batch_size,
        shuffle=shuffle,
        action_chunk_size=action_horizon,
        action_space=data_config.action_space,
        datasets=data_config.datasets,
    )
" "Make sure to run `scripts/compute_norm_stats.py --config-name=`." ) norm_stats = data_config.norm_stats return TransformedDataset( dataset, [ *data_config.repack_transforms.inputs, *data_config.data_transforms.inputs, _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm), *data_config.model_transforms.inputs, ], ) def transform_iterable_dataset( dataset: IterableDataset, data_config: _config.DataConfig, *, skip_norm_stats: bool = False, is_batched: bool = False, ) -> IterableDataset: """Transform the dataset by applying the data transforms.""" norm_stats = {} if data_config.repo_id != "fake" and not skip_norm_stats: if data_config.norm_stats is None: raise ValueError( "Normalization stats not found. " "Make sure to run `scripts/compute_norm_stats.py --config-name=`." ) norm_stats = data_config.norm_stats return IterableTransformedDataset( dataset, [ *data_config.repack_transforms.inputs, *data_config.data_transforms.inputs, _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm), *data_config.model_transforms.inputs, ], is_batched=is_batched, ) def create_data_loader( config: _config.TrainConfig, *, sharding: jax.sharding.Sharding | None = None, shuffle: bool = False, num_batches: int | None = None, skip_norm_stats: bool = False, framework: Literal["jax", "pytorch"] = "jax", ) -> DataLoader[tuple[_model.Observation, _model.Actions]]: """Create a data loader for training. Args: config: The training configuration. sharding: The sharding to use for the data loader (JAX only). shuffle: Whether to shuffle the data. num_batches: Determines the number of batches to return. skip_norm_stats: Whether to skip data normalization. framework: The framework to use ("jax" or "pytorch"). 
def create_torch_data_loader(
    data_config: _config.DataConfig,
    model_config: _model.BaseModelConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
    framework: str = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Args:
        data_config: The data configuration.
        model_config: The model configuration, forwarded to `create_torch_dataset`.
        action_horizon: The action horizon.
        batch_size: The global batch size. It is divided by the PyTorch world size
            (when `torch.distributed` is initialized) or by the JAX process count to
            obtain the per-process batch size.
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding. Ignored when `framework` is "pytorch".
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        num_workers: The number of worker processes to use. If zero, the data loader will
            execute in the main process.
        seed: The seed to use for shuffling the data.
        framework: Either "jax" or "pytorch". Selects how the batch is split across
            processes and whether batches are returned as JAX arrays or torch tensors.

    Raises:
        ValueError: If `batch_size` is smaller than the number of processes it has to be
            split across, which would give a per-process batch size of zero.
    """
    dataset = create_torch_dataset(data_config, action_horizon, model_config)
    dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats)

    # TorchDataLoader is used for both frameworks. Only the way the global batch is split
    # across processes differs: world size for PyTorch DDP, process count for JAX.
    sampler = None
    num_shards = 1
    if framework == "pytorch":
        if torch.distributed.is_initialized():
            num_shards = torch.distributed.get_world_size()
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=num_shards,
                rank=torch.distributed.get_rank(),
                shuffle=shuffle,
                # Honour the configured seed; the sampler otherwise always shuffles with seed 0.
                seed=seed,
                drop_last=True,
            )
    else:
        num_shards = jax.process_count()

    local_batch_size = batch_size // num_shards
    if local_batch_size < 1:
        raise ValueError(
            f"Batch size ({batch_size}) must be at least the number of processes ({num_shards})."
        )
    if batch_size % num_shards != 0:
        # Floor division silently shrinks the effective global batch; make that visible.
        logging.warning(
            f"Batch size {batch_size} is not divisible by {num_shards} processes; "
            f"the effective global batch size is {local_batch_size * num_shards}."
        )

    logging.info(f"local_batch_size: {local_batch_size}")
    data_loader = TorchDataLoader(
        dataset,
        local_batch_size=local_batch_size,
        sharding=None if framework == "pytorch" else sharding,
        shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
        sampler=sampler,
        num_batches=num_batches,
        num_workers=num_workers,
        seed=seed,
        framework=framework,
    )

    return DataLoaderImpl(data_config, data_loader)
If the number exceeds the number of batches in the dataset, the data loader will loop over the dataset. If not provided, will iterate over the dataset indefinitely. """ if framework == "pytorch": raise NotImplementedError("PyTorch RLDS data loader is not supported yet") dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle) dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True) data_loader = RLDSDataLoader( dataset, sharding=sharding, num_batches=num_batches, ) return DataLoaderImpl(data_config, data_loader) class TorchDataLoader: """Torch data loader implementation.""" def __init__( self, dataset, local_batch_size: int, *, sharding: jax.sharding.Sharding | None = None, shuffle: bool = False, sampler: torch.utils.data.Sampler | None = None, num_batches: int | None = None, num_workers: int = 0, seed: int = 0, framework: str = "jax", ): """Create a PyTorch data loader. Args: dataset: The dataset to load. local_batch_size: The local batch size for each process. sharding: The sharding to use for the data loader. shuffle: Whether to shuffle the data. num_batches: If provided, determines the number of returned batches. If the number is larger than the number of batches in the dataset, the data loader will loop over the dataset. If not provided, will iterate over the dataset indefinitely. num_workers: The number of worker processes to use. If zero, the data loader will execute in the main process. seed: The seed to use for shuffling the data. """ if jax.process_count() > 1: raise NotImplementedError("Data loading with multiple processes is not supported.") if len(dataset) < local_batch_size: raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).") # Store sharding - None for PyTorch, JAX sharding for JAX self._sharding = sharding if sharding is None and framework == "jax": # Use data parallel sharding by default for JAX only. 
self._sharding = jax.sharding.NamedSharding( jax.sharding.Mesh(jax.devices(), ("B",)), jax.sharding.PartitionSpec("B"), ) self._num_batches = num_batches mp_context = None if num_workers > 0: mp_context = multiprocessing.get_context("spawn") generator = torch.Generator() generator.manual_seed(seed) self._data_loader = torch.utils.data.DataLoader( typing.cast(torch.utils.data.Dataset, dataset), batch_size=local_batch_size, shuffle=(sampler is None and shuffle), # Don't shuffle if using sampler sampler=sampler, num_workers=num_workers, multiprocessing_context=mp_context, persistent_workers=num_workers > 0, collate_fn=_collate_fn, worker_init_fn=_worker_init_fn, drop_last=True, generator=generator, ) @property def torch_loader(self) -> torch.utils.data.DataLoader: return self._data_loader def __iter__(self): num_items = 0 while True: data_iter = iter(self._data_loader) while True: if self._num_batches is not None and num_items >= self._num_batches: return try: batch = next(data_iter) except StopIteration: break # We've exhausted the dataset. Create a new iterator and start over. num_items += 1 # For JAX, convert to sharded arrays; for PyTorch, return torch tensors if self._sharding is not None: yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch) else: yield jax.tree.map(torch.as_tensor, batch) def _collate_fn(items): """Collate the batch elements into batched numpy arrays.""" # Make sure to convert to numpy arrays before stacking since some of the incoming elements # may be JAX arrays. return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items) def _worker_init_fn(worker_id: int) -> None: """Tell JAX inside the worker process not to preallocate the GPU memory.""" # NOTE: This is called after jax is imported inside the worker process. This # means that this approach will not work for selecting the backend. 
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" class RLDSDataLoader: """Shallow wrapper around the DROID data loader to make it compatible with openpi. All batching already happens in the DROID dataset, so we don't need to do anything here. """ def __init__( self, dataset: DroidRldsDataset, *, sharding: jax.sharding.Sharding | None = None, num_batches: int | None = None, ): self._dataset = dataset self._num_batches = num_batches if jax.process_count() > 1: raise NotImplementedError("Data loading with multiple processes is not supported.") if sharding is None: # Use data parallel sharding by default. sharding = jax.sharding.NamedSharding( jax.sharding.Mesh(jax.devices(), ("B",)), jax.sharding.PartitionSpec("B"), ) self._sharding = sharding self._num_batches = num_batches def __iter__(self): num_items = 0 while True: data_iter = iter(self._dataset) while True: if self._num_batches is not None and num_items >= self._num_batches: return try: batch = next(data_iter) except StopIteration: break # We've exhausted the dataset. Create a new iterator and start over. 
class DataLoaderImpl(DataLoader):
    """Pairs a low-level batch loader with the data config it was built from."""

    def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader):
        """Store the data config and the loader whose batches will be converted."""
        self._data_loader = data_loader
        self._data_config = data_config

    def data_config(self) -> _config.DataConfig:
        """Return the data config passed at construction."""
        return self._data_config

    def __iter__(self):
        """Yield `(observation, actions)` pairs built from each raw batch."""
        yield from (
            (_model.Observation.from_dict(raw_batch), raw_batch["actions"]) for raw_batch in self._data_loader
        )
def test_with_real_dataset():
    """The `pi0_aloha_sim` config yields two batches of correctly shaped actions."""
    config = dataclasses.replace(_config.get_config("pi0_aloha_sim"), batch_size=4)
    expected_shape = (config.batch_size, config.model.action_horizon, config.model.action_dim)

    # Normalization is skipped since we may not have the stats available.
    loader = _data_loader.create_data_loader(config, skip_norm_stats=True, num_batches=2, shuffle=True)

    # The data config must be reachable through the loader.
    assert loader.data_config().repo_id == config.data.repo_id

    batches = list(loader)
    assert len(batches) == 2

    for _observation, actions in batches:
        assert actions.shape == expected_shape
""" from collections.abc import Sequence import dataclasses from enum import Enum from enum import auto import json import logging from pathlib import Path import tqdm import openpi.shared.download as download class DroidActionSpace(Enum): """Action space for DROID dataset.""" JOINT_POSITION = auto() JOINT_VELOCITY = auto() @dataclasses.dataclass class RLDSDataset: name: str version: str weight: float filter_dict_path: str | None = None class DroidRldsDataset: def __init__( self, data_dir: str, batch_size: int, datasets: Sequence[RLDSDataset], *, # Force keyword-only arguments shuffle: bool = True, action_chunk_size: int = 16, # We default to joint position actions, since they allow policy evaluation in simulation. action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION, max_loaded_steps_per_episode: int = 100, # Reduce this if you are running out of memory, but careful -- below ~100k shuffling is not sufficiently random. shuffle_buffer_size: int = 250_000, num_parallel_reads: int = -1, # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level num_parallel_calls: int = -1, # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level ): # Import tensorflow here to not make it mandatory in case RLDS data loader is not used. 
import dlimp as dl import tensorflow as tf import tensorflow_datasets as tfds # Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch / JAX) tf.config.set_visible_devices([], "GPU") # Ensure dataset weights sum to 1.0 assert sum(dataset.weight for dataset in datasets) == 1.0, "Dataset weights must sum to 1.0" def prepare_single_dataset(dataset_cfg: RLDSDataset): # ds_name, version = dataset_name.split(":") ds_name, version = dataset_cfg.name, dataset_cfg.version builder = tfds.builder(ds_name, data_dir=data_dir, version=version) dataset = dl.DLataset.from_rlds( builder, split="train", shuffle=shuffle, num_parallel_reads=num_parallel_reads ) # Filter out any unsuccessful trajectories -- we use the file name to check this dataset = dataset.filter( lambda traj: tf.strings.regex_full_match( traj["traj_metadata"]["episode_metadata"]["file_path"][0], ".*success.*" ) ) # Repeat dataset so we never run out of data. dataset = dataset.repeat() # Load the filter dictionary if provided. # The filter dictionary is a JSON file that maps episode keys to ranges of frames to sample # (e.g., # { # "": [[0, 100], [200, 300]] # } # means keep frames 0-99 and 200-299). 
filter_dict_path = dataset_cfg.filter_dict_path if filter_dict_path is not None: cached_filter_dict_path = download.maybe_download(filter_dict_path) with Path(cached_filter_dict_path).open("r") as f: filter_dict = json.load(f) logging.info(f"Using filter dictionary with {len(filter_dict)} episodes") keys_tensor = [] values_tensor = [] for episode_key, ranges in tqdm.tqdm(filter_dict.items(), desc="Creating idle filter hash table..."): for start, end in ranges: for t in range(start, end): frame_key = f"{episode_key}--{t}" keys_tensor.append(frame_key) values_tensor.append(True) self.filter_table = tf.lookup.StaticHashTable( tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor), default_value=False ) logging.info("Filter hash table initialized") else: self.filter_table = tf.lookup.StaticHashTable( tf.lookup.KeyValueTensorInitializer([""], [True]), default_value=True ) def restructure(traj): """Reformat observation and action keys, sample language instruction.""" # Important: we use joint *position* action space -- easier to simulate! actions = tf.concat( ( ( traj["action_dict"]["joint_position"] if action_space == DroidActionSpace.JOINT_POSITION else traj["action_dict"]["joint_velocity"] ), traj["action_dict"]["gripper_position"], ), axis=-1, ) # Randomly samples one of the two exterior images in DROID during training (we only train with one at a time). # Note: the "left" refers to the left camera in the stereo pair, we only train on the left camera. 
exterior_img = tf.cond( tf.random.uniform(shape=[]) > 0.5, lambda: traj["observation"]["exterior_image_1_left"], lambda: traj["observation"]["exterior_image_2_left"], ) wrist_img = traj["observation"]["wrist_image_left"] # Randomly sample one of the three language instructions instruction = tf.random.shuffle( [traj["language_instruction"], traj["language_instruction_2"], traj["language_instruction_3"]] )[0] traj_len = tf.shape(traj["action"])[0] indices = tf.as_string(tf.range(traj_len)) # Data filtering: # Compute a uniquely-identifying step ID by concatenating the recording folderpath, file path, # and each step's time step index. This will index into the filter hash table, and if it returns true, # then the frame passes the filter. step_id = ( traj["traj_metadata"]["episode_metadata"]["recording_folderpath"] + "--" + traj["traj_metadata"]["episode_metadata"]["file_path"] + "--" + indices ) passes_filter = self.filter_table.lookup(step_id) return { "actions": actions, "observation": { "image": exterior_img, "wrist_image": wrist_img, "joint_position": traj["observation"]["joint_position"], "gripper_position": traj["observation"]["gripper_position"], }, "prompt": instruction, "step_id": step_id, "passes_filter": passes_filter, } dataset = dataset.traj_map(restructure, num_parallel_calls) def chunk_actions(traj): """Splits episode into action chunks.""" traj_len = tf.shape(traj["actions"])[0] # For each step in the trajectory, construct indices for the next n actions action_chunk_indices = tf.broadcast_to( tf.range(action_chunk_size)[None], [traj_len, action_chunk_size], ) + tf.broadcast_to( tf.range(traj_len)[:, None], [traj_len, action_chunk_size], ) # Cap to length of the sequence --> final chunks will repeat the last action # This makes sense, since we are using absolute joint + gripper position actions action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1) # Gather the actions for each chunk traj["actions"] = tf.gather(traj["actions"], 
action_chunk_indices) return traj dataset = dataset.traj_map(chunk_actions, num_parallel_calls) # Flatten: map from trajectory dataset to dataset of individual action chunks dataset = dataset.flatten(num_parallel_calls=num_parallel_calls) # Filter data that doesn't pass the filter def filter_from_dict(frame): return frame["passes_filter"] dataset = dataset.filter(filter_from_dict) # Remove "passes_filter" key from output def remove_passes_filter(frame): frame.pop("passes_filter") return frame dataset = dataset.map(remove_passes_filter) # Decode images: RLDS saves encoded images, only decode now for efficiency def decode_images(traj): traj["observation"]["image"] = tf.io.decode_image( traj["observation"]["image"], expand_animations=False, dtype=tf.uint8 ) traj["observation"]["wrist_image"] = tf.io.decode_image( traj["observation"]["wrist_image"], expand_animations=False, dtype=tf.uint8 ) return traj return dataset.frame_map(decode_images, num_parallel_calls) logging.info(f"Preparing {len(datasets)} datasets...") logging.info("-" * 50) for dataset in datasets: logging.info(f" {dataset.name}:{dataset.version} with weight {dataset.weight:.2f}") logging.info("-" * 50) all_datasets = [prepare_single_dataset(dataset) for dataset in datasets] weights = [dataset.weight for dataset in datasets] final_dataset = dl.DLataset.sample_from_datasets(all_datasets, weights=weights) final_dataset = final_dataset.shuffle(shuffle_buffer_size) final_dataset = final_dataset.batch(batch_size) # Note =>> Seems to reduce memory usage without affecting speed? final_dataset = final_dataset.with_ram_budget(1) self.dataset = final_dataset self.batch_size = batch_size self.shuffle = shuffle def __iter__(self): yield from self.dataset.as_numpy_iterator() def __len__(self): # This is the approximate number of samples in DROID after filtering. # Easier to hardcode than to iterate through the dataset and compute it. 
return 20_000_000 ================================================ FILE: src/openpi/training/misc/polaris_config.py ================================================ """PolaRiS baseline policy configs.""" from typing import TypeAlias import openpi.models.model as _model import openpi.models.pi0_config as pi0_config import openpi.models.pi0_fast as pi0_fast import openpi.models.tokenizer as _tokenizer import openpi.policies.droid_policy as droid_policy import openpi.training.droid_rlds_dataset as droid_rlds_dataset import openpi.training.optimizer as _optimizer import openpi.training.weight_loaders as weight_loaders import openpi.transforms as _transforms ModelType: TypeAlias = _model.ModelType def get_polaris_configs(): # Import here to avoid circular imports. from openpi.training.config import AssetsConfig from openpi.training.config import RLDSDroidDataConfig from openpi.training.config import SimpleDataConfig from openpi.training.config import TrainConfig return [ # # PolaRiS DROID jointpos policies # TrainConfig( name="pi05_droid_jointpos_polaris", model=pi0_config.Pi0Config(action_horizon=15, pi05=True), data=RLDSDroidDataConfig( assets=AssetsConfig( assets_dir="gs://openpi-assets/checkpoints/polaris/pi05_droid_jointpos_polaris/assets", asset_id="droid", ), datasets=( droid_rlds_dataset.RLDSDataset( name="droid", version="1.0.1", weight=0.9, filter_dict_path="gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json", ), droid_rlds_dataset.RLDSDataset( name="polaris_droid_cotrain_dataset", version="1.0.0", weight=0.1, filter_dict_path="gs://openpi-assets/droid/polaris_droid_cotrain_dataset_sample_ranges_v1_0_0.json", ), ), rlds_data_dir="", action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, ), weight_loader=weight_loaders.CheckpointWeightLoader( "gs://openpi-assets/checkpoints/polaris/pi05_droid_jointpos_polaris/params" ), lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=1_000, peak_lr=5e-5, decay_steps=1_000_000, decay_lr=5e-5, ), 
num_train_steps=1_000, batch_size=128, log_interval=100, save_interval=1000, keep_period=1000, num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally ), TrainConfig( name="pi0_fast_droid_jointpos_polaris", model=pi0_fast.Pi0FASTConfig( action_dim=8, action_horizon=10, max_token_len=180, ), data=RLDSDroidDataConfig( assets=AssetsConfig( assets_dir="gs://openpi-assets/checkpoints/polaris/pi0_fast_droid_jointpos_polaris/assets", asset_id="droid", ), datasets=( droid_rlds_dataset.RLDSDataset( name="droid", version="1.0.1", weight=0.9, filter_dict_path="gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json", ), droid_rlds_dataset.RLDSDataset( name="polaris_droid_cotrain_dataset", version="1.0.0", weight=0.1, filter_dict_path="gs://openpi-assets/droid/polaris_droid_cotrain_dataset_sample_ranges_v1_0_0.json", ), ), rlds_data_dir="", action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, ), weight_loader=weight_loaders.CheckpointWeightLoader( "gs://openpi-assets/checkpoints/polaris/pi0_fast_droid_jointpos_polaris/params" ), lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=1_000, peak_lr=5e-5, decay_steps=1_000_000, decay_lr=5e-5, ), num_train_steps=1_000, batch_size=128, log_interval=100, save_interval=1000, keep_period=1000, num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally ), TrainConfig( name="pi0_droid_jointpos_polaris", model=pi0_config.Pi0Config( # action_dim=8, # leave as 32 default... 
action_horizon=10, max_token_len=100, ), data=RLDSDroidDataConfig( assets=AssetsConfig( assets_dir="gs://openpi-assets/checkpoints/polaris/pi0_droid_jointpos_polaris/assets", asset_id="droid", ), datasets=( droid_rlds_dataset.RLDSDataset( name="droid", version="1.0.1", weight=0.9, filter_dict_path="gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json", ), droid_rlds_dataset.RLDSDataset( name="polaris_droid_cotrain_dataset", version="1.0.0", weight=0.1, filter_dict_path="gs://openpi-assets/droid/polaris_droid_cotrain_dataset_sample_ranges_v1_0_0.json", ), ), rlds_data_dir="", action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, ), weight_loader=weight_loaders.CheckpointWeightLoader( "gs://openpi-assets/checkpoints/polaris/pi0_droid_jointpos_polaris/params" ), lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=1_000, peak_lr=5e-5, decay_steps=1_000_000, decay_lr=5e-5, ), num_train_steps=1_000, batch_size=128, log_interval=100, save_interval=1000, keep_period=1000, num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally ), TrainConfig( name="pi0_droid_jointpos_100k_polaris", model=pi0_config.Pi0Config( # action_dim=8, # leave as 32 default... 
action_horizon=10, max_token_len=100, ), data=RLDSDroidDataConfig( assets=AssetsConfig( assets_dir="gs://openpi-assets/checkpoints/polaris/pi0_droid_jointpos_100k_polaris/assets", asset_id="droid", ), datasets=( droid_rlds_dataset.RLDSDataset( name="droid", version="1.0.1", weight=0.9, filter_dict_path="gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json", ), droid_rlds_dataset.RLDSDataset( name="polaris_droid_cotrain_dataset", version="1.0.0", weight=0.1, filter_dict_path="gs://openpi-assets/droid/polaris_droid_cotrain_dataset_sample_ranges_v1_0_0.json", ), ), rlds_data_dir="", action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION, ), weight_loader=weight_loaders.CheckpointWeightLoader( "gs://openpi-assets/checkpoints/polaris/pi0_droid_jointpos_100k_polaris/params" ), lr_schedule=_optimizer.CosineDecaySchedule( warmup_steps=1_000, peak_lr=5e-5, decay_steps=1_000_000, decay_lr=5e-5, ), num_train_steps=1_000, batch_size=128, log_interval=100, save_interval=1000, keep_period=1000, num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally ), # openpi doesn't support finetuning of binning policies, so this is an inference-only config TrainConfig( name="paligemma_binning_droid_jointpos", model=pi0_fast.Pi0FASTConfig( action_dim=8, action_horizon=15, max_token_len=600, fast_model_tokenizer=_tokenizer.BinningTokenizer, ), data=SimpleDataConfig( assets=AssetsConfig(asset_id="droid"), data_transforms=lambda model: _transforms.Group( inputs=[droid_policy.DroidInputs(model_type=ModelType.PI0_FAST)], outputs=[ _transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)), droid_policy.DroidOutputs(), ], ), ), ), ] ================================================ FILE: src/openpi/training/misc/roboarena_config.py ================================================ """RoboArena baseline policy configs.""" from typing import TypeAlias import openpi.models.model as _model import openpi.models.pi0_config as pi0_config 
def get_roboarena_configs():
    """Return the RoboArena DROID baseline inference configs.

    All five configs use the same data setup (DROID assets, DROID input/output transforms,
    prompt taken from the task); only the arguments of the DROID input transform and the
    model config differ.
    """
    # Import here to avoid circular imports.
    from openpi.training.config import AssetsConfig
    from openpi.training.config import DataConfig
    from openpi.training.config import SimpleDataConfig
    from openpi.training.config import TrainConfig

    def _droid_data(**droid_input_kwargs):
        # Shared data config; `droid_input_kwargs` are extra keyword arguments for `DroidInputs`.
        return SimpleDataConfig(
            assets=AssetsConfig(asset_id="droid"),
            data_transforms=lambda model: _transforms.Group(
                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, **droid_input_kwargs)],
                outputs=[droid_policy.DroidOutputs()],
            ),
            base_config=DataConfig(
                prompt_from_task=True,
            ),
        )

    # Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
    binning = TrainConfig(
        name="paligemma_binning_droid",
        model=pi0_fast.Pi0FASTConfig(
            action_dim=8,
            action_horizon=15,
            max_token_len=400,
            fast_model_tokenizer=_tokenizer.BinningTokenizer,
        ),
        data=_droid_data(model_type=ModelType.PI0_FAST),
    )

    # Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
    fast = TrainConfig(
        name="paligemma_fast_droid",
        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
        data=_droid_data(model_type=ModelType.PI0_FAST),
    )

    # Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
    fast_specialist = TrainConfig(
        name="paligemma_fast_specialist_droid",
        model=pi0_fast.Pi0FASTConfig(
            action_dim=8,
            action_horizon=15,
            fast_model_tokenizer=_tokenizer.FASTTokenizer,
            fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
        ),
        data=_droid_data(model_type=ModelType.PI0_FAST),
    )

    # Trained from PaliGemma, using FSQ tokenizer.
    vq = TrainConfig(
        name="paligemma_vq_droid",
        model=pi0_fast.Pi0FASTConfig(
            action_dim=8,
            action_horizon=15,
            fast_model_tokenizer=_tokenizer.FSQTokenizer,
            fast_model_tokenizer_kwargs={"fsq_tokenizer_path": "gs://openpi-assets/tokenizers/droid_fsq_tokenizer"},
        ),
        data=_droid_data(model_type=ModelType.PI0_FAST),
    )

    # pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
    diffusion = TrainConfig(
        name="paligemma_diffusion_droid",
        model=pi0_config.Pi0Config(action_horizon=10, action_dim=8),
        data=_droid_data(),
    )

    return [binning, fast, fast_specialist, vq, diffusion]
@dataclasses.dataclass(frozen=True) class CosineDecaySchedule(LRScheduleConfig): """Cosine decay schedule with warmup.""" warmup_steps: int = 1_000 peak_lr: float = 2.5e-5 decay_steps: int = 30_000 decay_lr: float = 2.5e-6 def create(self) -> optax.Schedule: return optax.warmup_cosine_decay_schedule( init_value=self.peak_lr / (self.warmup_steps + 1), peak_value=self.peak_lr, warmup_steps=self.warmup_steps, decay_steps=self.decay_steps, end_value=self.decay_lr, ) @dataclasses.dataclass(frozen=True) class RsqrtDecaySchedule(LRScheduleConfig): """Inverse square root decay schedule with warmup.""" warmup_steps: int = 1_000 peak_lr: float = 5e-5 timescale: float = 10_000 def create(self) -> optax.Schedule: return optax.join_schedules( [ optax.linear_schedule( init_value=self.peak_lr / (self.warmup_steps + 1), end_value=self.peak_lr, transition_steps=self.warmup_steps, ), lambda step: self.peak_lr / jnp.sqrt((self.timescale + step) / self.timescale), ], [self.warmup_steps], ) @runtime_checkable class OptimizerConfig(Protocol): def create( self, lr: optax.ScalarOrSchedule, weight_decay_mask: at.PyTree | None = None, ) -> optax.GradientTransformation: ... @dataclasses.dataclass(frozen=True) class AdamW(OptimizerConfig): """AdamW optimizer.""" b1: float = 0.9 b2: float = 0.95 eps: float = 1e-8 # Changing this to 0 can cause out-of-memory errors for some reason, so we set it to a negligible value. 
weight_decay: float = 1e-10 clip_gradient_norm: float = 1.0 def create( self, lr: optax.ScalarOrSchedule, weight_decay_mask: at.PyTree | None = None, ) -> optax.GradientTransformation: tx = optax.adamw( lr, b1=self.b1, b2=self.b2, eps=self.eps, weight_decay=self.weight_decay, mask=weight_decay_mask ) return optax.chain(optax.clip_by_global_norm(self.clip_gradient_norm), tx) @dataclasses.dataclass(frozen=True) class SGD(OptimizerConfig): """SGD optimizer.""" lr: float = 5e-5 momentum: float = 0.9 nesterov: bool = False def create( self, lr: optax.ScalarOrSchedule, weight_decay_mask: at.PyTree | None = None, ) -> optax.GradientTransformation: assert weight_decay_mask is None, "Weight decay is not supported for SGD" return optax.sgd(lr, momentum=self.momentum, nesterov=self.nesterov) def create_optimizer( optimizer: OptimizerConfig, lr_schedule: LRScheduleConfig, weight_decay_mask: at.PyTree | None = None ) -> optax.GradientTransformation: lr = lr_schedule.create() return optimizer.create(lr, weight_decay_mask=weight_decay_mask) ================================================ FILE: src/openpi/training/sharding.py ================================================ import contextlib import logging import jax import numpy as np BATCH_AXIS = "batch" FSDP_AXIS = "fsdp" # In FSDP, we shard the data across both the batch and FSDP axes. DATA_AXIS = (BATCH_AXIS, FSDP_AXIS) class _MeshState: active_mesh: jax.sharding.Mesh | None = None def make_mesh(num_fsdp_devices: int) -> jax.sharding.Mesh: if jax.device_count() % num_fsdp_devices != 0: raise ValueError( f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {num_fsdp_devices}." 
) mesh_shape = (jax.device_count() // num_fsdp_devices, num_fsdp_devices) return jax.make_mesh(mesh_shape, (BATCH_AXIS, FSDP_AXIS)) @contextlib.contextmanager def set_mesh(mesh: jax.sharding.Mesh): """Plumbing the mesh deep into the module tree is extremely cumbersome; until the JAX team lands a better API, a custom context manager like this one is the recommended way to maintain a reference to a global mesh. This is only used in `activation_sharding_constraint` below.""" if _MeshState.active_mesh is not None: raise ValueError("Cannot nest set_mesh context managers.") _MeshState.active_mesh = mesh try: yield finally: _MeshState.active_mesh = None def activation_sharding_constraint(pytree): if _MeshState.active_mesh is None: return pytree return jax.lax.with_sharding_constraint( pytree, jax.sharding.NamedSharding(_MeshState.active_mesh, jax.sharding.PartitionSpec(DATA_AXIS)) ) def fsdp_sharding( pytree, mesh: jax.sharding.Mesh, *, min_size_mbytes: int = 4, # 4 MiB log: bool = False, ): """Apply FSDP sharding to a pytree of arrays based on the mesh shape. Args: pytree: A pytree to be apply sharding specified by the mesh, note that only array types (eg. contains .shape attr) will be considered for sharding. mesh: The mesh being used for applying sharding on to pytree. min_size_mbytes: The minimum size of the array in MiB to be considered for sharding, any array smaller than this will be replicated. log: If true, will log the sharding decisions for arrays that are being considered for sharding. Returns: The sharded pytree. 
""" min_size_bytes = min_size_mbytes * 2**20 def _shard_arr(kp, array: jax.ShapeDtypeStruct): # if fsdp is not actually going to be used, replicate everything to avoid extraneous logging if mesh.shape[FSDP_AXIS] == 1: return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) # replicate scalar and vector arrays if not hasattr(array, "shape"): return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) if len(array.shape) < 2: return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) # replicate small arrays if (arr_size := np.prod(array.shape) * np.dtype(array.dtype).itemsize) < min_size_bytes: return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) # shard matrices and larger tensors along the largest axis that is divisible by the fsdp dimension axes = np.argsort(array.shape)[::-1] spec = [None] * len(axes) for i in axes: if array.shape[i] % mesh.shape[FSDP_AXIS] == 0: if log: logging.info( f"Sharding {jax.tree_util.keystr(kp)} of shape {array.shape} ({arr_size / 2**20:.2f} MiB) along axis {i}" ) spec[i] = FSDP_AXIS return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec)) # replicate if no valid sharding was found if log: logging.warning( f"Could not find a valid sharding for {jax.tree_util.keystr(kp)} of shape {array.shape} with mesh of shape {mesh.shape}" ) return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) return jax.tree_util.tree_map_with_path(_shard_arr, pytree) ================================================ FILE: src/openpi/training/utils.py ================================================ from collections.abc import Callable from typing import Any from flax import nnx from flax import struct import jax import optax from openpi.models import model as _model from openpi.shared import array_typing as at @at.typecheck @struct.dataclass class TrainState: step: at.Int[at.ArrayLike, ""] params: nnx.State model_def: nnx.GraphDef[_model.BaseModel] opt_state: optax.OptState tx: 
optax.GradientTransformation = struct.field(pytree_node=False) ema_decay: float | None = struct.field(pytree_node=False) ema_params: nnx.State | None = None @at.typecheck def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str: """Converts a PyTree into a human-readable string for logging. Optionally, `interp_func` can be provided to convert the leaf values to more meaningful strings. """ tree, _ = jax.tree_util.tree_flatten_with_path(tree) return "\n".join(f"{jax.tree_util.keystr(path)}: {interp_func(value)}" for path, value in tree) @at.typecheck def array_tree_to_info(tree: at.PyTree) -> str: """Converts a PyTree of arrays into a human-readable string for logging.""" return tree_to_info(tree, lambda x: f"{x.shape}@{x.dtype}") ================================================ FILE: src/openpi/training/weight_loaders.py ================================================ import dataclasses import logging import re from typing import Protocol, runtime_checkable import flax.traverse_util import numpy as np import openpi.models.model as _model import openpi.shared.array_typing as at import openpi.shared.download as download logger = logging.getLogger(__name__) @runtime_checkable class WeightLoader(Protocol): def load(self, params: at.Params) -> at.Params: """Loads the model weights. Args: params: Parameters of the model. This is a nested structure of array-like objects that represent the model's parameters. Returns: Loaded parameters. The structure must be identical to `params`. If returning a subset of the parameters the loader must merge the loaded parameters with `params`. """ @dataclasses.dataclass(frozen=True) class NoOpWeightLoader(WeightLoader): def load(self, params: at.Params) -> at.Params: return params @dataclasses.dataclass(frozen=True) class CheckpointWeightLoader(WeightLoader): """Loads an entire set of weights from a checkpoint. 
Compatible with: trained checkpoints: example: "./checkpoints////params" released checkpoints: example: "gs://openpi-assets/checkpoints//params" """ params_path: str def load(self, params: at.Params) -> at.Params: # We are loading np.ndarray and relying on the training code to properly convert and shard the params. loaded_params = _model.restore_params(download.maybe_download(self.params_path), restore_type=np.ndarray) # Add all missing LoRA weights. return _merge_params(loaded_params, params, missing_regex=".*lora.*") @dataclasses.dataclass(frozen=True) class PaliGemmaWeightLoader(WeightLoader): """Loads weights from the official PaliGemma checkpoint. This will overwrite existing weights with similar names while keeping all extra weights intact. This allows us to support the action expert which is used by the Pi0 model. """ def load(self, params: at.Params) -> at.Params: path = download.maybe_download( "gs://vertex-model-garden-paligemma-us/paligemma/pt_224.npz", gs={"token": "anon"} ) with path.open("rb") as f: flat_params = dict(np.load(f, allow_pickle=False)) loaded_params = {"PaliGemma": flax.traverse_util.unflatten_dict(flat_params, sep="/")["params"]} # Add all missing weights. return _merge_params(loaded_params, params, missing_regex=".*") def _merge_params(loaded_params: at.Params, params: at.Params, *, missing_regex: str) -> at.Params: """Merges the loaded parameters with the reference parameters. Args: loaded_params: The parameters to merge. params: The reference parameters. missing_regex: A regex pattern for all missing keys that should be merged from the reference parameters. Returns: A new dictionary with the merged parameters. """ flat_ref = flax.traverse_util.flatten_dict(params, sep="/") flat_loaded = flax.traverse_util.flatten_dict(loaded_params, sep="/") # First, take all weights that are a subset of the reference weights. 
result = {} for k, v in flat_loaded.items(): if k in flat_ref: result[k] = v.astype(flat_ref[k].dtype) if v.dtype != flat_ref[k].dtype else v flat_loaded.clear() # Then, merge any missing weights as defined by the missing regex. pattern = re.compile(missing_regex) for k in {k for k in flat_ref if pattern.fullmatch(k)}: if k not in result: result[k] = flat_ref[k] return flax.traverse_util.unflatten_dict(result, sep="/") ================================================ FILE: src/openpi/transforms.py ================================================ from collections.abc import Callable, Mapping, Sequence import dataclasses import re from typing import Protocol, TypeAlias, TypeVar, runtime_checkable import flax.traverse_util as traverse_util import jax import numpy as np from openpi_client import image_tools from openpi.models import tokenizer as _tokenizer from openpi.shared import array_typing as at from openpi.shared import normalize as _normalize DataDict: TypeAlias = at.PyTree NormStats: TypeAlias = _normalize.NormStats T = TypeVar("T") S = TypeVar("S") @runtime_checkable class DataTransformFn(Protocol): def __call__(self, data: DataDict) -> DataDict: """Apply transformation to the data. Args: data: The data to apply the transform to. This is a possibly nested dictionary that contains unbatched data elements. Each leaf is expected to be a numpy array. Using JAX arrays is allowed but not recommended since it may result in extra GPU memory usage inside data loader worker processes. Returns: The transformed data. Could be the input `data` that was modified in place, or a new data structure. """ @dataclasses.dataclass(frozen=True) class Group: """A group of transforms.""" # Transforms that are applied to the model input data. inputs: Sequence[DataTransformFn] = () # Transforms that are applied to the model output data. 
outputs: Sequence[DataTransformFn] = () def push(self, *, inputs: Sequence[DataTransformFn] = (), outputs: Sequence[DataTransformFn] = ()) -> "Group": """Append transforms to the group and return a new group. Args: inputs: Appended to the *end* of the current input transforms. outputs: Appended to the *beginning* of the current output transforms. Returns: A new group with the appended transforms. """ return Group(inputs=(*self.inputs, *inputs), outputs=(*outputs, *self.outputs)) @dataclasses.dataclass(frozen=True) class CompositeTransform(DataTransformFn): """A composite transform that applies a sequence of transforms in order.""" transforms: Sequence[DataTransformFn] def __call__(self, data: DataDict) -> DataDict: for transform in self.transforms: data = transform(data) return data def compose(transforms: Sequence[DataTransformFn]) -> DataTransformFn: """Compose a sequence of transforms into a single transform.""" return CompositeTransform(transforms) @dataclasses.dataclass(frozen=True) class RepackTransform(DataTransformFn): """Repacks an input dictionary into a new dictionary. Repacking is defined using a dictionary where the keys are the new keys and the values are the flattened paths to the old keys. We use '/' as the separator during flattening. 
Example: { "images": { "cam_high": "observation.images.top", "cam_low": "observation.images.bottom", }, "state": "observation.state", "actions": "action", } """ structure: at.PyTree[str] def __call__(self, data: DataDict) -> DataDict: flat_item = flatten_dict(data) return jax.tree.map(lambda k: flat_item[k], self.structure) @dataclasses.dataclass(frozen=True) class InjectDefaultPrompt(DataTransformFn): prompt: str | None def __call__(self, data: DataDict) -> DataDict: if self.prompt is not None and "prompt" not in data: data["prompt"] = np.asarray(self.prompt) return data @dataclasses.dataclass(frozen=True) class Normalize(DataTransformFn): norm_stats: at.PyTree[NormStats] | None # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used. use_quantiles: bool = False # If true, will raise an error if any of the keys in the norm stats are not present in the data. strict: bool = False def __post_init__(self): if self.norm_stats is not None and self.use_quantiles: _assert_quantile_stats(self.norm_stats) def __call__(self, data: DataDict) -> DataDict: if self.norm_stats is None: return data return apply_tree( data, self.norm_stats, self._normalize_quantile if self.use_quantiles else self._normalize, strict=self.strict, ) def _normalize(self, x, stats: NormStats): mean, std = stats.mean[..., : x.shape[-1]], stats.std[..., : x.shape[-1]] return (x - mean) / (std + 1e-6) def _normalize_quantile(self, x, stats: NormStats): assert stats.q01 is not None assert stats.q99 is not None q01, q99 = stats.q01[..., : x.shape[-1]], stats.q99[..., : x.shape[-1]] return (x - q01) / (q99 - q01 + 1e-6) * 2.0 - 1.0 @dataclasses.dataclass(frozen=True) class Unnormalize(DataTransformFn): norm_stats: at.PyTree[NormStats] | None # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used. 
use_quantiles: bool = False def __post_init__(self): if self.norm_stats is not None and self.use_quantiles: _assert_quantile_stats(self.norm_stats) def __call__(self, data: DataDict) -> DataDict: if self.norm_stats is None: return data # Make sure that all the keys in the norm stats are present in the data. return apply_tree( data, self.norm_stats, self._unnormalize_quantile if self.use_quantiles else self._unnormalize, strict=True, ) def _unnormalize(self, x, stats: NormStats): mean = pad_to_dim(stats.mean, x.shape[-1], axis=-1, value=0.0) std = pad_to_dim(stats.std, x.shape[-1], axis=-1, value=1.0) return x * (std + 1e-6) + mean def _unnormalize_quantile(self, x, stats: NormStats): assert stats.q01 is not None assert stats.q99 is not None q01, q99 = stats.q01, stats.q99 if (dim := q01.shape[-1]) < x.shape[-1]: return np.concatenate([(x[..., :dim] + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01, x[..., dim:]], axis=-1) return (x + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01 @dataclasses.dataclass(frozen=True) class ResizeImages(DataTransformFn): height: int width: int def __call__(self, data: DataDict) -> DataDict: data["image"] = {k: image_tools.resize_with_pad(v, self.height, self.width) for k, v in data["image"].items()} return data @dataclasses.dataclass(frozen=True) class SubsampleActions(DataTransformFn): stride: int def __call__(self, data: DataDict) -> DataDict: data["actions"] = data["actions"][:: self.stride] return data @dataclasses.dataclass(frozen=True) class DeltaActions(DataTransformFn): """Repacks absolute actions into delta action space.""" # Boolean mask for the action dimensions to be repacked into delta action space. Length # can be smaller than the actual number of dimensions. If None, this transform is a no-op. # See `make_bool_mask` for more details. 
mask: Sequence[bool] | None def __call__(self, data: DataDict) -> DataDict: if "actions" not in data or self.mask is None: return data state, actions = data["state"], data["actions"] mask = np.asarray(self.mask) dims = mask.shape[-1] actions[..., :dims] -= np.expand_dims(np.where(mask, state[..., :dims], 0), axis=-2) data["actions"] = actions return data @dataclasses.dataclass(frozen=True) class AbsoluteActions(DataTransformFn): """Repacks delta actions into absolute action space.""" # Boolean mask for the action dimensions to be repacked into absolute action space. Length # can be smaller than the actual number of dimensions. If None, this transform is a no-op. # See `make_bool_mask` for more details. mask: Sequence[bool] | None def __call__(self, data: DataDict) -> DataDict: if "actions" not in data or self.mask is None: return data state, actions = data["state"], data["actions"] mask = np.asarray(self.mask) dims = mask.shape[-1] actions[..., :dims] += np.expand_dims(np.where(mask, state[..., :dims], 0), axis=-2) data["actions"] = actions return data @dataclasses.dataclass(frozen=True) class TokenizePrompt(DataTransformFn): tokenizer: _tokenizer.PaligemmaTokenizer discrete_state_input: bool = False def __call__(self, data: DataDict) -> DataDict: if (prompt := data.pop("prompt", None)) is None: raise ValueError("Prompt is required") if self.discrete_state_input: if (state := data.get("state", None)) is None: raise ValueError("State is required.") else: state = None if not isinstance(prompt, str): prompt = prompt.item() tokens, token_masks = self.tokenizer.tokenize(prompt, state) return {**data, "tokenized_prompt": tokens, "tokenized_prompt_mask": token_masks} @dataclasses.dataclass(frozen=True) class TokenizeFASTInputs(DataTransformFn): tokenizer: _tokenizer.FASTTokenizer def __call__(self, data: DataDict) -> DataDict: if (prompt := data.pop("prompt", None)) is None: raise ValueError("Prompt is required") if not isinstance(prompt, str): prompt = prompt.item() 
state, actions = data["state"], data.get("actions") tokens, token_mask, ar_mask, loss_mask = self.tokenizer.tokenize(prompt, state, actions) return { **data, "tokenized_prompt": tokens, "tokenized_prompt_mask": token_mask, "token_ar_mask": ar_mask, "token_loss_mask": loss_mask, } @dataclasses.dataclass(frozen=True) class ExtractFASTActions(DataTransformFn): tokenizer: _tokenizer.FASTTokenizer action_horizon: int action_dim: int def __call__(self, data: DataDict) -> DataDict: if "actions" not in data: return data # Model outputs are saved in "actions", but for FAST models they represent tokens. tokens = data.pop("actions") actions = self.tokenizer.extract_actions(tokens.astype(np.int32), self.action_horizon, self.action_dim) return { **data, "actions": actions, } @dataclasses.dataclass(frozen=True) class PromptFromLeRobotTask(DataTransformFn): """Extracts a prompt from the current LeRobot dataset task.""" # Contains the LeRobot dataset tasks (dataset.meta.tasks). tasks: dict[int, str] def __call__(self, data: DataDict) -> DataDict: if "task_index" not in data: raise ValueError('Cannot extract prompt without "task_index"') task_index = int(data["task_index"]) if (prompt := self.tasks.get(task_index)) is None: raise ValueError(f"{task_index=} not found in task mapping: {self.tasks}") return {**data, "prompt": prompt} @dataclasses.dataclass(frozen=True) class PadStatesAndActions(DataTransformFn): """Zero-pads states and actions to the model action dimension.""" model_action_dim: int def __call__(self, data: DataDict) -> DataDict: data["state"] = pad_to_dim(data["state"], self.model_action_dim, axis=-1) if "actions" in data: data["actions"] = pad_to_dim(data["actions"], self.model_action_dim, axis=-1) return data def flatten_dict(tree: at.PyTree) -> dict: """Flatten a nested dictionary. Uses '/' as the separator.""" return traverse_util.flatten_dict(tree, sep="/") def unflatten_dict(tree: dict) -> at.PyTree: """Unflatten a flattened dictionary. 
Assumes that '/' was used as a separator.""" return traverse_util.unflatten_dict(tree, sep="/") def transform_dict(patterns: Mapping[str, str | None], tree: at.PyTree) -> at.PyTree: """Transform the structure of a nested dictionary using a set of patterns. The transformation is defined using the `patterns` dictionary. The keys are the input keys that should be matched and the values are the new names inside the output dictionary. If the value is None, the input key is removed. Both keys and values should represent flattened paths using '/' as the separator. Keys can be regular expressions and values can include backreferences to the matched groups (see `re.sub` for more details). Note that the regular expression must match the entire key. The order inside the `patterns` dictionary is important. Only the first pattern that matches the input key will be used. See unit tests for more examples. Args: patterns: A mapping from old keys to new keys. tree: The nested dictionary to transform. Returns: The transformed nested dictionary. """ data = flatten_dict(tree) # Compile the patterns. compiled = {re.compile(k): v for k, v in patterns.items()} output = {} for k in data: for pattern, repl in compiled.items(): if pattern.fullmatch(k): new_k = pattern.sub(repl, k, count=1) if repl is not None else None break else: # Use the original key if no match is found. new_k = k if new_k is not None: if new_k in output: raise ValueError(f"Key '{new_k}' already exists in output") output[new_k] = data[k] # Validate the output structure to make sure that it can be unflattened. 
names = sorted(output) for i in range(len(names) - 1): name, next_name = names[i : i + 2] if next_name.startswith(name + "/"): raise ValueError(f"Leaf '{name}' aliases a node of '{next_name}'") return unflatten_dict(output) def apply_tree( tree: at.PyTree[T], selector: at.PyTree[S], fn: Callable[[T, S], T], *, strict: bool = False ) -> at.PyTree[T]: tree = flatten_dict(tree) selector = flatten_dict(selector) def transform(k: str, v: T) -> T: if k in selector: return fn(v, selector[k]) return v if strict: for k in selector: if k not in tree: raise ValueError(f"Selector key {k} not found in tree") return unflatten_dict({k: transform(k, v) for k, v in tree.items()}) def pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray: """Pad an array to the target dimension with zeros along the specified axis.""" current_dim = x.shape[axis] if current_dim < target_dim: pad_width = [(0, 0)] * len(x.shape) pad_width[axis] = (0, target_dim - current_dim) return np.pad(x, pad_width, constant_values=value) return x def make_bool_mask(*dims: int) -> tuple[bool, ...]: """Make a boolean mask for the given dimensions. Example: make_bool_mask(2, -2, 2) == (True, True, False, False, True, True) make_bool_mask(2, 0, 2) == (True, True, True, True) Args: dims: The dimensions to make the mask for. Returns: A tuple of booleans. """ result = [] for dim in dims: if dim > 0: result.extend([True] * (dim)) else: result.extend([False] * (-dim)) return tuple(result) def _assert_quantile_stats(norm_stats: at.PyTree[NormStats]) -> None: for k, v in flatten_dict(norm_stats).items(): if v.q01 is None or v.q99 is None: raise ValueError( f"quantile stats must be provided if use_quantile_norm is True. Key {k} is missing q01 or q99." 
) ================================================ FILE: src/openpi/transforms_test.py ================================================ import numpy as np import pytest import openpi.models.tokenizer as _tokenizer import openpi.transforms as _transforms def test_repack_transform(): transform = _transforms.RepackTransform( structure={ "a": {"b": "b/c"}, "d": "e/f", } ) item = {"b": {"c": 1}, "e": {"f": 2}} assert transform(item) == {"a": {"b": 1}, "d": 2} def test_delta_actions(): item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])} transform = _transforms.DeltaActions(mask=[False, True]) transformed = transform(item) assert np.all(transformed["state"] == np.array([1, 2, 3])) assert np.all(transformed["actions"] == np.array([[3, 2, 5], [5, 4, 7]])) def test_delta_actions_noop(): item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])} # No-op when the mask is disabled. transform = _transforms.DeltaActions(mask=None) assert transform(item) is item # No-op when there are no actions in the input. del item["actions"] transform = _transforms.DeltaActions(mask=[True, False]) assert transform(item) is item def test_absolute_actions(): item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])} transform = _transforms.AbsoluteActions(mask=[False, True]) transformed = transform(item) assert np.all(transformed["state"] == np.array([1, 2, 3])) assert np.all(transformed["actions"] == np.array([[3, 6, 5], [5, 8, 7]])) def test_absolute_actions_noop(): item = {"state": np.array([1, 2, 3]), "actions": np.array([[3, 4, 5], [5, 6, 7]])} # No-op when the mask is disabled. transform = _transforms.AbsoluteActions(mask=None) assert transform(item) is item # No-op when there are no actions in the input. 
del item["actions"] transform = _transforms.AbsoluteActions(mask=[True, False]) assert transform(item) is item def test_make_bool_mask(): assert _transforms.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True) assert _transforms.make_bool_mask(2, 0, 2) == (True, True, True, True) def test_tokenize_prompt(): tokenizer = _tokenizer.PaligemmaTokenizer(max_len=12) transform = _transforms.TokenizePrompt(tokenizer) data = transform({"prompt": "Hello, world!"}) tok_prompt, tok_mask = tokenizer.tokenize("Hello, world!") assert np.allclose(tok_prompt, data["tokenized_prompt"]) assert np.allclose(tok_mask, data["tokenized_prompt_mask"]) def test_tokenize_no_prompt(): transform = _transforms.TokenizePrompt(_tokenizer.PaligemmaTokenizer()) with pytest.raises(ValueError, match="Prompt is required"): transform({}) def test_transform_dict(): # Rename and remove keys. input = {"a": {"b": 1, "c": 2}} output = _transforms.transform_dict({"a/b": "a/c", "a/c": None}, input) assert output == {"a": {"c": 1}} # Raises and error since the renamed key conflicts with an existing key. with pytest.raises(ValueError, match="Key 'a/c' already exists in output"): _transforms.transform_dict({"a/b": "a/c"}, input) # Full match is required and so nothing will be removed. input = {"a": {"b": 1, "c": 2}} output = _transforms.transform_dict({"a": None}, input) assert output == input # The regex matches the entire key and so the entire input will be removed. input = {"a": {"b": 1, "c": 2}} output = _transforms.transform_dict({"a.+": None}, input) assert output == {} # Replace keys using backreferences. All leaves named 'c' are replaced with 'd'. input = {"a": {"b": 1, "c": 1}, "b": {"c": 2}} output = _transforms.transform_dict({"(.+)/c": r"\1/d"}, input) assert output == {"a": {"b": 1, "d": 1}, "b": {"d": 2}} def test_extract_prompt_from_task(): transform = _transforms.PromptFromLeRobotTask({1: "Hello, world!"}) data = transform({"task_index": 1}) assert data["prompt"] == "Hello, world!" 
with pytest.raises(ValueError, match="task_index=2 not found in task mapping"): transform({"task_index": 2})