[
  {
    "path": ".dockerignore",
    "content": ".venv\ncheckpoints\ndata\n"
  },
  {
    "path": ".github/CODEOWNERS",
    "content": "# The CODEOWNERS file defines individuals or teams that are automatically requested for\n# review when someone opens a pull request that modifies certain code. When a draft pull\n# request is marked as ready for review, code owners are automatically notified.\n#\n# See: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners\n#\n# This is a comment.\n# Each line is a file pattern followed by one or more owners.\n\n# Global owners.\n* @jimmyt857 @Michael-Equi @kvablack\n\nsrc/openpi/models/ @kvablack\nsrc/openpi/training/ @kvablack\n\nscripts/ @jimmyt857 @kvablack"
  },
  {
    "path": ".github/workflows/pre-commit.yml",
    "content": "name: pre-commit\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - \"*\"\njobs:\n  pre-commit:\n    runs-on: ubuntu-latest\n    env:\n      GIT_LFS_SKIP_SMUDGE: true\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v3\n      - uses: pre-commit/action@v3.0.1\n"
  },
  {
    "path": ".github/workflows/test.yml",
    "content": "name: Test\non:\n  pull_request:\n    branches:\n      - \"*\"\n\njobs:\n  run_tests:\n    name: Run Tests\n    runs-on: openpi-verylarge\n    env:\n      GIT_LFS_SKIP_SMUDGE: true\n    steps:\n      - uses: actions/checkout@v4\n\n      - name: Install FFmpeg dependencies\n        run: |\n          sudo apt-get update\n          sudo apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev\n\n      - name: Install uv\n        uses: astral-sh/setup-uv@v5\n\n      - name: Set up Python\n        run: uv python install\n\n      - name: Install the project\n        run: uv sync --all-extras --dev\n\n      - name: Run tests\n        run: uv run pytest --strict-markers -m \"not manual\"\n"
  },
  {
    "path": ".gitignore",
    "content": "# Data directories.\nassets/\ncheckpoints/\ndata/\nwandb/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control\n.pdm.toml\n.pdm-python\n.pdm-build/\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"third_party/aloha\"]\n\tpath = third_party/aloha\n\turl = https://github.com/Physical-Intelligence/aloha.git\n[submodule \"third_party/libero\"]\n\tpath = third_party/libero\n\turl = https://github.com/Lifelong-Robot-Learning/LIBERO.git\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "exclude: third_party/\n\nrepos:\n  - repo: https://github.com/astral-sh/uv-pre-commit\n    # uv version.\n    rev: 0.5.14\n    hooks:\n      - id: uv-lock\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    # Ruff version.\n    rev: v0.8.6\n    hooks:\n      # Run the linter.\n      - id: ruff\n        args: [--fix]\n      - id: ruff-format"
  },
  {
    "path": ".python-version",
    "content": "3.11"
  },
  {
    "path": ".vscode/settings.json",
    "content": "{\n    \"[python]\": {\n        \"editor.defaultFormatter\": \"charliermarsh.ruff\",\n        \"editor.formatOnSave\": true,\n    },\n    \"python.testing.pytestArgs\": [\n        \"src\"\n    ],\n    \"python.testing.unittestEnabled\": false,\n    \"python.testing.pytestEnabled\": true\n}"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to openpi\n\nWe welcome contributions, improvements, and modifications. Everyone is welcome to use openpi in accordance to the [license](LICENSE). Contributors are also welcome to submit bug reports, feature requests, and pull requests. We can't promise to approve every pull request, and we are a small team with limited bandwidth to review all requests, but we'll give it our best effort. Specifics are described below.\n\n## Issues and feature requests\n\nYou are welcome to use the Github [discussion](https://github.com/Physical-Intelligence/openpi/discussions) feature if you would like to discuss something that is not directly reporting an issue or making a feature request. This is suitable for questions about how to use some aspect of openpi, or other topics.\n\nIf you found a bug or other issue, please first check that the issue was not already reported (use the search bar on Github under Issues). If the issue has not yet been reported, please include this information when filing a Github issue:\n\n- Your OS type and version and the version of Python you are using\n- Code that allows us to reproduce your bug, including all dependencies\n- Traceback of any exception\n- Any other information that would help us, such as a screenshot\n\nIn order for us to address any issue, we must be able to reproduce it, so if you encountered the issue after making modifications to openpi, please reproduce the issue without any other modifications and provide a code snippet that allows us to quickly reproduce the problem on `main`.\n\nIf you would like to submit a feature request, please check that the feature request does not already exist, and please provide the following information:\n\n- The motivation for the feature\n- A description of the problem you are trying to solve or your use case\n- Enough information for us to understand the nature of the request\n- Some information for how you intend to use it (this might help us in understanding the 
motivation!)\n\nWe can't promise to support every feature request, but it is helpful to us to know the use cases that you are interested in!\n\n## Submitting a pull request\n\nIf you have implemented support for a new robot or environment, or some other new feature, we welcome pull requests (PRs) to openpi. We encourage you to create a [feature request](https://github.com/Physical-Intelligence/openpi/issues) or make a post on the [discussion](https://github.com/Physical-Intelligence/openpi/discussions) board before starting to work on your PR, if you would like to get a sense for whether we are likely to approve your PR if it is submitted. Since we are a small team with limited ability to provide maintenance and support, we may not accept all PRs (e.g., if we believe it would make the code harder to maintain, or if reviewing the PR is out of scope for us), so contacting us in advance is a good way to get a sense for whether your PR is likely to get approved for merging into openpi directly. But even if it isn't, you are of course more than welcome to maintain your own fork with whatever modifications you would like. When creating PRs, we recommend that every contribution consider the following:\n\n- Make sure that your PR has a clear title and description\n- Run `pre-commit` (install using `pre-commit install` first), and run `ruff check .` and `ruff format .`\n- Make sure your PR passes all tests\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License."
  },
  {
    "path": "LICENSE_GEMMA.txt",
    "content": "Gemma Terms of Use \n\nLast modified: February 21, 2024\n\nBy using, reproducing, modifying, distributing, performing or displaying any portion or element of Gemma, Model Derivatives including via any Hosted Service, (each as defined below) (collectively, the \"Gemma Services\") or otherwise accepting the terms of this Agreement, you agree to be bound by this Agreement.\n\nSection 1: DEFINITIONS\n1.1 Definitions\n(a) \"Agreement\" or \"Gemma Terms of Use\" means these terms and conditions that govern the use, reproduction, Distribution or modification of the Gemma Services and any terms and conditions incorporated by reference.\n\n(b) \"Distribution\" or \"Distribute\" means any transmission, publication, or other sharing of Gemma or Model Derivatives to a third party, including by providing or making Gemma or its functionality available as a hosted service via API, web access, or any other electronic or remote means (\"Hosted Service\").\n\n(c) \"Gemma\" means the set of machine learning language models, trained model weights and parameters identified at ai.google.dev/gemma, regardless of the source that you obtained it from.\n\n(d) \"Google\" means Google LLC.\n\n(e) \"Model Derivatives\" means all (i) modifications to Gemma, (ii) works based on Gemma, or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Gemma, to that model in order to cause that model to perform similarly to Gemma, including distillation methods that use intermediate data representations or methods based on the generation of synthetic data Outputs by Gemma for training that model. 
For clarity, Outputs are not deemed Model Derivatives.\n\n(f) \"Output\" means the information content output of Gemma or a Model Derivative that results from operating or otherwise using Gemma or the Model Derivative, including via a Hosted Service.\n\n1.2\nAs used in this Agreement, \"including\" means \"including without limitation\".\n\nSection 2: ELIGIBILITY AND USAGE\n2.1 Eligibility\nYou represent and warrant that you have the legal capacity to enter into this Agreement (including being of sufficient age of consent). If you are accessing or using any of the Gemma Services for or on behalf of a legal entity, (a) you are entering into this Agreement on behalf of yourself and that legal entity, (b) you represent and warrant that you have the authority to act on behalf of and bind that entity to this Agreement and (c) references to \"you\" or \"your\" in the remainder of this Agreement refers to both you (as an individual) and that entity.\n\n2.2 Use\nYou may use, reproduce, modify, Distribute, perform or display any of the Gemma Services only in accordance with the terms of this Agreement, and must not violate (or encourage or permit anyone else to violate) any term of this Agreement.\n\nSection 3: DISTRIBUTION AND RESTRICTIONS\n3.1 Distribution and Redistribution\nYou may reproduce or Distribute copies of Gemma or Model Derivatives if you meet all of the following conditions:\n\nYou must include the use restrictions referenced in Section 3.2 as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) 
governing the use and/or distribution of Gemma or Model Derivatives and you must provide notice to subsequent users you Distribute to that Gemma or Model Derivatives are subject to the use restrictions in Section 3.2.\nYou must provide all third party recipients of Gemma or Model Derivatives a copy of this Agreement.\nYou must cause any modified files to carry prominent notices stating that you modified the files.\nAll Distributions (other than through a Hosted Service) must be accompanied by a \"Notice\" text file that contains the following notice: \"Gemma is provided under and subject to the Gemma Terms of Use found at ai.google.dev/gemma/terms\".\nYou may add your own intellectual property statement to your modifications and, except as set forth in this Section, may provide additional or different terms and conditions for use, reproduction, or Distribution of your modifications, or for any such Model Derivatives as a whole, provided your use, reproduction, modification, Distribution, performance, and display of Gemma otherwise complies with the terms and conditions of this Agreement. Any additional or different terms and conditions you impose must not conflict with the terms of this Agreement.\n\n3.2 Use Restrictions\nYou must not use any of the Gemma Services:\n\nfor the restricted uses set forth in the Gemma Prohibited Use Policy at ai.google.dev/gemma/prohibited_use_policy (\"Prohibited Use Policy\"), which is hereby incorporated by reference into this Agreement; or\nin violation of applicable laws and regulations.\nTo the maximum extent permitted by law, Google reserves the right to restrict (remotely or otherwise) usage of any of the Gemma Services that Google reasonably believes are in violation of this Agreement.\n\n3.3 Generated Output\nGoogle claims no rights in Outputs you generate using Gemma. 
You and your users are solely responsible for Outputs and their subsequent uses.\n\nSection 4: ADDITIONAL PROVISIONS\n4.1 Updates\nGoogle may update Gemma from time to time, and you must make reasonable efforts to use the latest version of Gemma.\n\n4.2 Trademarks\nNothing in this Agreement grants you any rights to use Google's trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between you and Google. Google reserves any rights not expressly granted herein.\n\n4.3 DISCLAIMER OF WARRANTY\nUNLESS REQUIRED BY APPLICABLE LAW, THE GEMMA SERVICES, AND OUTPUTS, ARE PROVIDED ON AN \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR OR DISTRIBUTING ANY OF THE GEMMA SERVICES OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR USE OR DISTRIBUTION OF ANY OF THE GEMMA SERVICES OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.\n\n4.4 LIMITATION OF LIABILITY\nTO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), PRODUCT LIABILITY, CONTRACT, OR OTHERWISE, UNLESS REQUIRED BY APPLICABLE LAW, SHALL GOOGLE OR ITS AFFILIATES BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL, OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO, ANY OF THE GEMMA SERVICES OR OUTPUTS EVEN IF GOOGLE OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.\n\n4.5 Term, Termination, and Survival\nThe term of this Agreement will commence upon your acceptance of this Agreement (including acceptance by your use, modification, or Distribution, reproduction, performance 
or display of any portion or element of the Gemma Services) and will continue in full force and effect until terminated in accordance with the terms of this Agreement. Google may terminate this Agreement if you are in breach of any term of this Agreement. Upon termination of this Agreement, you must delete and cease use and Distribution of all copies of Gemma and Model Derivatives in your possession or control. Sections 1, 2.1, 3.3, 4.2 to 4.9 shall survive the termination of this Agreement.\n\n4.6 Governing Law and Jurisdiction\nThis Agreement will be governed by the laws of the State of California without regard to choice of law principles. The UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The state and federal courts of Santa Clara County, California shall have exclusive jurisdiction of any dispute arising out of this Agreement.\n\n4.7 Severability\nIf any provision of this Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.\n\n4.8 Entire Agreement\nThis Agreement states all the terms agreed between the parties and supersedes all other agreements between the parties as of the date of acceptance relating to its subject matter.\n\n4.9 No Waiver\nGoogle will not be treated as having waived any rights by not exercising (or delaying the exercise of) any rights under this Agreement."
  },
  {
    "path": "README.md",
    "content": "# openpi\n\nopenpi holds open-source models and packages for robotics, published by the [Physical Intelligence team](https://www.physicalintelligence.company/).\n\nCurrently, this repo contains three types of models:\n- the [π₀ model](https://www.physicalintelligence.company/blog/pi0), a flow-based vision-language-action model (VLA).\n- the [π₀-FAST model](https://www.physicalintelligence.company/research/fast), an autoregressive VLA, based on the FAST action tokenizer.\n- the [π₀.₅ model](https://www.physicalintelligence.company/blog/pi05), an upgraded version of π₀ with better open-world generalization trained with [knowledge insulation](https://www.physicalintelligence.company/research/knowledge_insulation). Note that, in this repository, we currently only support the flow matching head for both $\\pi_{0.5}$ training and inference.\n\nFor all models, we provide _base model_ checkpoints, pre-trained on 10k+ hours of robot data, and examples for using them out of the box or fine-tuning them to your own datasets.\n\nThis is an experiment: $\\pi_0$ was developed for our own robots, which differ from the widely used platforms such as [ALOHA](https://tonyzhaozh.github.io/aloha/) and [DROID](https://droid-dataset.github.io/), and though we are optimistic that researchers and practitioners will be able to run creative new experiments adapting $\\pi_0$ to their own platforms, we do not expect every such attempt to be successful. 
All this is to say: $\\pi_0$ may or may not work for you, but you are welcome to try it and see!\n\n## Updates\n\n- [Sept 2025] We released PyTorch support in openpi.\n- [Sept 2025] We released pi05, an upgraded version of pi0 with better open-world generalization.\n- [Sept 2025]: We have added an [improved idle filter](examples/droid/README_train.md#data-filtering) for DROID training.\n- [Jun 2025]: We have added [instructions](examples/droid/README_train.md) for using `openpi` to train VLAs on the full [DROID dataset](https://droid-dataset.github.io/). This is an approximate open-source implementation of the training pipeline used to train pi0-FAST-DROID. \n\n\n## Requirements\n\nTo run the models in this repository, you will need an NVIDIA GPU with at least the following specifications. These estimations assume a single GPU, but you can also use multiple GPUs with model parallelism to reduce per-GPU memory requirements by configuring `fsdp_devices` in the training config. Please also note that the current training script does not yet support multi-node training.\n\n| Mode               | Memory Required | Example GPU        |\n| ------------------ | --------------- | ------------------ |\n| Inference          | > 8 GB          | RTX 4090           |\n| Fine-Tuning (LoRA) | > 22.5 GB       | RTX 4090           |\n| Fine-Tuning (Full) | > 70 GB         | A100 (80GB) / H100 |\n\nThe repo has been tested with Ubuntu 22.04, we do not currently support other operating systems.\n\n## Installation\n\nWhen cloning this repo, make sure to update submodules:\n\n```bash\ngit clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git\n\n# Or if you already cloned the repo:\ngit submodule update --init --recursive\n```\n\nWe use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up. 
Once uv is installed, run the following to set up the environment:\n\n```bash\nGIT_LFS_SKIP_SMUDGE=1 uv sync\nGIT_LFS_SKIP_SMUDGE=1 uv pip install -e .\n```\n\nNOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency.\n\n**Docker**: As an alternative to uv installation, we provide instructions for installing openpi using Docker. If you encounter issues with your system setup, consider using Docker to simplify installation. See [Docker Setup](docs/docker.md) for more details.\n\n\n\n\n## Model Checkpoints\n\n### Base Models\nWe provide multiple base VLA model checkpoints. These checkpoints have been pre-trained on 10k+ hours of robot data, and can be used for fine-tuning.\n\n| Model        | Use Case    | Description                                                                                                 | Checkpoint Path                                |\n| ------------ | ----------- | ----------------------------------------------------------------------------------------------------------- | ---------------------------------------------- |\n| $\\pi_0$      | Fine-Tuning | Base [π₀ model](https://www.physicalintelligence.company/blog/pi0) for fine-tuning                | `gs://openpi-assets/checkpoints/pi0_base`      |\n| $\\pi_0$-FAST | Fine-Tuning | Base autoregressive [π₀-FAST model](https://www.physicalintelligence.company/research/fast) for fine-tuning | `gs://openpi-assets/checkpoints/pi0_fast_base` |\n| $\\pi_{0.5}$    | Fine-Tuning | Base [π₀.₅ model](https://www.physicalintelligence.company/blog/pi05) for fine-tuning    | `gs://openpi-assets/checkpoints/pi05_base`      |\n\n### Fine-Tuned Models\nWe also provide \"expert\" checkpoints for various robot platforms and tasks. These models are fine-tuned from the base models above and intended to run directly on the target robot. These may or may not work on your particular robot. 
Since these checkpoints were fine-tuned on relatively small datasets collected with more widely available robots, such as ALOHA and the DROID Franka setup, they might not generalize to your particular setup, though we found some of these, especially the DROID checkpoint, to generalize quite broadly in practice.\n\n| Model                    | Use Case    | Description                                                                                                                                                                                              | Checkpoint Path                                       |\n| ------------------------ | ----------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------- |\n| $\\pi_0$-FAST-DROID       | Inference   | $\\pi_0$-FAST model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/): can perform a wide range of simple table-top manipulation tasks 0-shot in new scenes on the DROID robot platform | `gs://openpi-assets/checkpoints/pi0_fast_droid`       |\n| $\\pi_0$-DROID            | Fine-Tuning | $\\pi_0$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/): faster inference than $\\pi_0$-FAST-DROID, but may not follow language commands as well                                | `gs://openpi-assets/checkpoints/pi0_droid`            |\n| $\\pi_0$-ALOHA-towel      | Inference   | $\\pi_0$ model fine-tuned on internal [ALOHA](https://tonyzhaozh.github.io/aloha/) data: can fold diverse towels 0-shot on ALOHA robot platforms                                                          | `gs://openpi-assets/checkpoints/pi0_aloha_towel`      |\n| $\\pi_0$-ALOHA-tupperware | Inference   | $\\pi_0$ model fine-tuned on internal [ALOHA](https://tonyzhaozh.github.io/aloha/) data: can unpack food from a 
tupperware container                                                                                                             | `gs://openpi-assets/checkpoints/pi0_aloha_tupperware` |\n| $\\pi_0$-ALOHA-pen-uncap  | Inference   | $\\pi_0$ model fine-tuned on public [ALOHA](https://dit-policy.github.io/) data: can uncap a pen                                                                                                          | `gs://openpi-assets/checkpoints/pi0_aloha_pen_uncap`  |\n| $\\pi_{0.5}$-LIBERO      | Inference   | $\\pi_{0.5}$ model fine-tuned for the [LIBERO](https://libero-project.github.io/datasets) benchmark: gets state-of-the-art performance (see [LIBERO README](examples/libero/README.md)) | `gs://openpi-assets/checkpoints/pi05_libero`      |\n| $\\pi_{0.5}$-DROID      | Inference / Fine-Tuning | $\\pi_{0.5}$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/) with [knowledge insulation](https://www.physicalintelligence.company/research/knowledge_insulation): fast inference and good language-following | `gs://openpi-assets/checkpoints/pi05_droid`      |\n\n\nBy default, checkpoints are automatically downloaded from `gs://openpi-assets` and are cached in `~/.cache/openpi` when needed. 
You can override the download path by setting the `OPENPI_DATA_HOME` environment variable.\n\n\n\n\n## Running Inference for a Pre-Trained Model\n\nOur pre-trained model checkpoints can be run with a few lines of code (here our $\\pi_{0.5}$-DROID model):\n```python\nfrom openpi.training import config as _config\nfrom openpi.policies import policy_config\nfrom openpi.shared import download\n\nconfig = _config.get_config(\"pi05_droid\")\ncheckpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi05_droid\")\n\n# Create a trained policy.\npolicy = policy_config.create_trained_policy(config, checkpoint_dir)\n\n# Run inference on a dummy example.\nexample = {\n    \"observation/exterior_image_1_left\": ...,\n    \"observation/wrist_image_left\": ...,\n    ...\n    \"prompt\": \"pick up the fork\"\n}\naction_chunk = policy.infer(example)[\"actions\"]\n```\nYou can also test this out in the [example notebook](examples/inference.ipynb).\n\nWe provide detailed step-by-step examples for running inference of our pre-trained checkpoints on [DROID](examples/droid/README.md) and [ALOHA](examples/aloha_real/README.md) robots.\n\n**Remote Inference**: We provide [examples and code](docs/remote_inference.md) for running inference of our models **remotely**: the model can run on a different server and stream actions to the robot via a websocket connection. This makes it easy to use more powerful GPUs off-robot and keep robot and policy environments separate.\n\n**Test inference without a robot**: We provide a [script](examples/simple_client/README.md) for testing inference without a robot. This script will generate a random observation and run inference with the model. 
See [here](examples/simple_client/README.md) for more details.\n\n\n\n\n\n## Fine-Tuning Base Models on Your Own Data\n\nWe will fine-tune the $\\pi_{0.5}$ model on the [LIBERO dataset](https://libero-project.github.io/datasets) as a running example for how to fine-tune a base model on your own data. We will explain three steps:\n1. Convert your data to a LeRobot dataset (which we use for training)\n2. Defining training configs and running training\n3. Spinning up a policy server and running inference\n\n### 1. Convert your data to a LeRobot dataset\n\nWe provide a minimal example script for converting LIBERO data to a LeRobot dataset in [`examples/libero/convert_libero_data_to_lerobot.py`](examples/libero/convert_libero_data_to_lerobot.py). You can easily modify it to convert your own data! You can download the raw LIBERO dataset from [here](https://huggingface.co/datasets/openvla/modified_libero_rlds), and run the script with:\n\n```bash\nuv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/libero/data\n```\n\n**Note:** If you just want to fine-tune on LIBERO, you can skip this step, because our LIBERO fine-tuning configs point to a pre-converted LIBERO dataset. This step is merely an example that you can adapt to your own data.\n\n### 2. Defining training configs and running training\n\nTo fine-tune a base model on your own data, you need to define configs for data processing and training. We provide example configs with detailed comments for LIBERO below, which you can modify for your own dataset:\n\n- [`LiberoInputs` and `LiberoOutputs`](src/openpi/policies/libero_policy.py): Defines the data mapping from the LIBERO environment to the model and vice versa. 
Will be used for both training and inference.\n- [`LeRobotLiberoDataConfig`](src/openpi/training/config.py): Defines how to process raw LIBERO data from a LeRobot dataset for training.\n- [`TrainConfig`](src/openpi/training/config.py): Defines fine-tuning hyperparameters, data config, and weight loader.\n\nWe provide example fine-tuning configs for [π₀](src/openpi/training/config.py), [π₀-FAST](src/openpi/training/config.py), and [π₀.₅](src/openpi/training/config.py) on LIBERO data.\n\nBefore we can run training, we need to compute the normalization statistics for the training data. Run the script below with the name of your training config:\n\n```bash\nuv run scripts/compute_norm_stats.py --config-name pi05_libero\n```\n\nNow we can kick off training with the following command (the `--overwrite` flag is used to overwrite existing checkpoints if you rerun fine-tuning with the same config):\n\n```bash\nXLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_libero --exp-name=my_experiment --overwrite\n```\n\nThe command will log training progress to the console and save checkpoints to the `checkpoints` directory. You can also monitor training progress on the Weights & Biases dashboard. To make maximal use of GPU memory, set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` before running training -- this enables JAX to use up to 90% of the GPU memory (vs. the default of 75%).\n\n**Note:** We provide functionality for *reloading* normalization statistics for state / action normalization from pre-training. This can be beneficial if you are fine-tuning to a new task on a robot that was part of our pre-training mixture. For more details on how to reload normalization statistics, see the [norm_stats.md](docs/norm_stats.md) file.\n\n### 3. Spinning up a policy server and running inference\n\nOnce training is complete, we can run inference by spinning up a policy server and then querying it from a LIBERO evaluation script. 
Launching a model server is easy (in this example, we use the checkpoint from iteration 20,000; modify as needed):\n\n```bash\nuv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_libero --policy.dir=checkpoints/pi05_libero/my_experiment/20000\n```\n\nThis will spin up a server that listens on port 8000 and waits for observations to be sent to it. We can then run an evaluation script (or robot runtime) that queries the server.\n\nFor running the LIBERO eval in particular, we provide (and recommend using) a Dockerized workflow that handles both the policy server and the evaluation script together. See the [LIBERO README](examples/libero/README.md) for more details.\n\nIf you want to embed a policy server call in your own robot runtime, we have a minimal example of how to do so in the [remote inference docs](docs/remote_inference.md).\n\n\n\n### More Examples\n\nWe provide more examples for how to fine-tune and run inference with our models on the ALOHA and UR5 platforms in the following READMEs:\n- [ALOHA Simulator](examples/aloha_sim)\n- [ALOHA Real](examples/aloha_real)\n- [UR5](examples/ur5)\n\n## PyTorch Support\n\nopenpi now provides PyTorch implementations of π₀ and π₀.₅ models alongside the original JAX versions! The PyTorch implementation has been validated on the LIBERO benchmark (both inference and finetuning). A few features are currently not supported (this may change in the future):\n\n- The π₀-FAST model\n- Mixed precision training\n- FSDP (fully-sharded data parallelism) training\n- LoRA (low-rank adaptation) training\n- EMA (exponential moving average) weights during training\n\n### Setup\n1. Make sure that you have the latest version of all dependencies installed: `uv sync`\n\n2. Double check that you have transformers 4.53.2 installed: `uv pip show transformers`\n\n3. 
Apply the transformers library patches:\n   ```bash\n   cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/\n   ```\n\nThis overwrites several files in the transformers library with necessary model changes: 1) supporting AdaRMS, 2) correctly controlling the precision of activations, and 3) allowing the KV cache to be used without being updated.\n\n**WARNING**: With the default uv link mode (hardlink), this will permanently affect the transformers library in your uv cache, meaning the changes will survive reinstallations of transformers and could even propagate to other projects that use transformers. To fully undo this operation, you must run `uv cache clean transformers`.\n\n### Converting JAX Models to PyTorch\n\nTo convert a JAX model checkpoint to PyTorch format:\n\n```bash\nuv run examples/convert_jax_model_to_pytorch.py \\\n    --checkpoint_dir /path/to/jax/checkpoint \\\n    --config_name <config name> \\\n    --output_path /path/to/converted/pytorch/checkpoint\n```\n\n### Running Inference with PyTorch\n\nThe PyTorch implementation uses the same API as the JAX version - you only need to change the checkpoint path to point to the converted PyTorch model:\n\n```python\nfrom openpi.training import config as _config\nfrom openpi.policies import policy_config\nfrom openpi.shared import download\n\nconfig = _config.get_config(\"pi05_droid\")\ncheckpoint_dir = \"/path/to/converted/pytorch/checkpoint\"\n\n# Create a trained policy (automatically detects PyTorch format)\npolicy = policy_config.create_trained_policy(config, checkpoint_dir)\n\n# Run inference (same API as JAX)\naction_chunk = policy.infer(example)[\"actions\"]\n```\n\n### Policy Server with PyTorch\n\nThe policy server works identically with PyTorch models - just point to the converted checkpoint directory:\n\n```bash\nuv run scripts/serve_policy.py policy:checkpoint \\\n    --policy.config=pi05_droid \\\n    
--policy.dir=/path/to/converted/pytorch/checkpoint\n```\n\n### Finetuning with PyTorch\n\nTo finetune a model in PyTorch:\n\n1. Convert the JAX base model to PyTorch format:\n   ```bash\n   uv run examples/convert_jax_model_to_pytorch.py \\\n       --config_name <config name> \\\n       --checkpoint_dir /path/to/jax/base/model \\\n       --output_path /path/to/pytorch/base/model\n   ```\n\n2. Specify the converted PyTorch model path in your config using `pytorch_weight_path`\n\n3. Launch training using one of these modes:\n\n```bash\n# Single GPU training:\nuv run scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>\n\n# Example:\nuv run scripts/train_pytorch.py debug --exp_name pytorch_test\nuv run scripts/train_pytorch.py debug --exp_name pytorch_test --resume  # Resume from latest checkpoint\n\n# Multi-GPU training (single node):\nuv run torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>\n\n# Example:\nuv run torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test\nuv run torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume\n\n# Multi-Node Training:\nuv run torchrun \\\n    --nnodes=<num_nodes> \\\n    --nproc_per_node=<gpus_per_node> \\\n    --node_rank=<rank_of_node> \\\n    --master_addr=<master_ip> \\\n    --master_port=<port> \\\n    scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>\n```\n\n### Precision Settings\n\nJAX and PyTorch implementations handle precision as follows:\n\n**JAX:**\n1. Inference: most weights and computations in bfloat16, with a few computations in float32 for stability\n2. Training: defaults to mixed precision: weights and gradients in float32, (most) activations and computations in bfloat16. 
You can change to full float32 training by setting `dtype` to float32 in the config.\n\n**PyTorch:**\n1. Inference: matches JAX -- most weights and computations in bfloat16, with a few weights converted to float32 for stability\n2. Training: supports either full bfloat16 (default) or full float32. You can change it by setting `pytorch_training_precision` in the config. bfloat16 uses less memory but exhibits higher losses compared to float32. Mixed precision is not yet supported.\n\nWith torch.compile, inference speed is comparable between JAX and PyTorch.\n\n## Troubleshooting\n\nWe will collect common issues and their solutions here. If you encounter an issue, please check here first. If you can't find a solution, please file an issue on the repo (see [here](CONTRIBUTING.md) for guidelines).\n\n| Issue                                     | Resolution                                                                                                                                                                                   |\n| ----------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |\n| `uv sync` fails with dependency conflicts | Try removing the virtual environment directory (`rm -rf .venv`) and running `uv sync` again. If issues persist, check that you have the latest version of `uv` installed (`uv self update`). |\n| Training runs out of GPU memory           | Make sure you set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` (or higher) before running training to allow JAX to use more GPU memory. You can also use `--fsdp-devices <n>` where `<n>` is your number of GPUs, to enable [fully-sharded data parallelism](https://engineering.fb.com/2021/07/15/open-source/fsdp/), which reduces memory usage in exchange for slower training (the amount of slowdown depends on your particular setup). 
If you are still running out of memory, you may want to consider disabling EMA.        |\n| Policy server connection errors           | Check that the server is running and listening on the expected port. Verify network connectivity and firewall settings between client and server.                                            |\n| Missing norm stats error when training    | Run `scripts/compute_norm_stats.py` with your config name before starting training.                                                                                                          |\n| Dataset download fails                    | Check your internet connection. For HuggingFace datasets, ensure you're logged in (`huggingface-cli login`).                                                                                 |\n| CUDA/GPU errors                           | Verify NVIDIA drivers are installed correctly. For Docker, ensure nvidia-container-toolkit is installed. Check GPU compatibility. You do NOT need CUDA libraries installed at a system level --- they will be installed via uv. You may even want to try *uninstalling* system CUDA libraries if you run into CUDA issues, since system libraries can sometimes cause conflicts. |\n| Import errors when running examples       | Make sure you've installed all dependencies with `uv sync`. Some examples may have additional requirements listed in their READMEs.                    |\n| Action dimensions mismatch                | Verify your data processing transforms match the expected input/output dimensions of your robot. Check the action space definitions in your policy classes.                                  |\n| Diverging training loss                            | Check the `q01`, `q99`, and `std` values in `norm_stats.json` for your dataset. Certain dimensions that are rarely used can end up with very small `q01`, `q99`, or `std` values, leading to huge states and actions after normalization. 
You can manually adjust the norm stats as a workaround. |\n"
  },
  {
    "path": "docs/docker.md",
    "content": "### Docker Setup\n\nAll of the examples in this repo provide instructions for running them both natively and with Docker. Although not required, the Docker option is recommended as this will simplify software installation, produce a more stable environment, and also allow you to avoid installing ROS and cluttering your machine, for examples which depend on ROS.\n\n- Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/).\n- Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/).\n- To use your GPU, you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).\n- The version of Docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`.\n- Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`.\n\n\nIf starting from scratch and your host machine is Ubuntu 22.04, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.\n\nBuild the Docker image and start the container with the following command:\n```bash\ndocker compose -f scripts/docker/compose.yml up --build\n```\n\nTo build and run the Docker image for a specific example, use the following command:\n```bash\ndocker compose -f examples/<example_name>/compose.yml up --build\n```\nwhere `<example_name>` is the name of the example you want to run.\n\nDuring the first run of any example, Docker will build the images. Go grab a coffee while this happens. 
Subsequent runs will be faster since the images are cached."
  },
  {
    "path": "docs/norm_stats.md",
    "content": "# Normalization statistics\n\nFollowing common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint.\n\n## Reloading normalization statistics\n\nWhen you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model.\n\n**If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more \"familiar\" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, like below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint:\n\n```python\nTrainConfig(\n    ...\n    data=LeRobotAlohaDataConfig(\n        ...\n        assets=AssetsConfig(\n            assets_dir=\"gs://openpi-assets/checkpoints/pi0_base/assets\",\n            asset_id=\"trossen\",\n        ),\n    ),\n)\n```\n\nFor an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).\n\n**Note:** To successfully reload normalization statistics, it's important that your robot + dataset are following the action space definitions used in pre-training. 
We provide a detailed description of our action space definitions below.\n\n**Note #2:** Whether reloading normalization statistics is beneficial depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. We recommend always trying both options, i.e. reloading the statistics and training with a fresh set of statistics computed on your new dataset (see [main README](../README.md) for instructions on how to compute new statistics), and picking the one that works better for your task.\n\n\n## Provided Pre-training Normalization Statistics\n\nBelow is a list of all the pre-training normalization statistics we provide. We provide them for both the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_base/assets` and for `pi0_fast_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_fast_base/assets`.\n\n| Robot | Description | Asset ID |\n|-------|-------------|----------|\n| ALOHA | 6-DoF dual arm robot with parallel grippers | trossen |\n| Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile |\n| Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid |\n| Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka |\n| UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e |\n| UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual |\n| ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx |\n| ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile |\n| Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile |\n\n\n## Pi0 Model Action Space Definitions\n\nOut of the box, both the `pi0_base` and `pi0_fast_base` use the following action space definitions (left and right are defined looking from behind the robot towards the workspace):\n```\n    \"dim_0:dim_5\": \"left arm joint angles\",\n    
\"dim_6\": \"left arm gripper position\",\n    \"dim_7:dim_12\": \"right arm joint angles (for bi-manual only)\",\n    \"dim_13\": \"right arm gripper position (for bi-manual only)\",\n\n    # For mobile robots:\n    \"dim_14:dim_15\": \"x-y base velocity (for mobile robots only)\",\n```\n\nThe proprioceptive state uses the same definitions as the action space, except for the x-y base velocity (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state.\n\nFor 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.\n\nGeneral info for Pi robots:\n- Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details).\n- Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed.\n- Control frequencies are either 20 Hz for UR5e and Franka, or 50 Hz for ARX and Trossen (ALOHA) arms.\n\nFor DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions, gripper actions in the 8th dimension, and a control frequency of 15 Hz.\n"
  },
  {
    "path": "docs/remote_inference.md",
    "content": "\n# Running openpi models remotely\n\nWe provide utilities for running openpi models remotely. This is useful for running inference on more powerful GPUs off-robot, and also helps keep the robot and policy environments separate (and e.g. avoid dependency hell with robot software).\n\n## Starting a remote policy server\n\nTo start a remote policy server, you can simply run the following command:\n\n```bash\nuv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO]\n```\n\nThe `env` argument specifies which $\\pi_0$ checkpoint should be loaded. Under the hood, this script will execute a command like the following, which you can use to start a policy server, e.g. for checkpoints you trained yourself (here an example for the DROID environment):\n\n```bash\nuv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid\n```\n\nThis will start a policy server that will serve the policy specified by the `config` and `dir` arguments. The policy will be served on the specified port (default: 8000).\n\n## Querying the remote policy server from your robot code\n\nWe provide a client utility with minimal dependencies that you can easily embed into any robot codebase.\n\nFirst, install the `openpi-client` package in your robot environment:\n\n```bash\ncd $OPENPI_ROOT/packages/openpi-client\npip install -e .\n```\n\nThen, you can use the client to query the remote policy server from your robot code. 
Here's an example of how to do this:\n\n```python\nfrom openpi_client import image_tools\nfrom openpi_client import websocket_client_policy\n\n# Outside of episode loop, initialize the policy client.\n# Point to the host and port of the policy server (localhost and 8000 are the defaults).\nclient = websocket_client_policy.WebsocketClientPolicy(host=\"localhost\", port=8000)\n\nfor step in range(num_steps):\n    # Inside the episode loop, construct the observation.\n    # Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format.\n    # We provide utilities for resizing images + uint8 conversion so you match the training routines.\n    # The typical resize_size for pre-trained pi0 models is 224.\n    # Note that the proprioceptive `state` can be passed unnormalized, normalization will be handled on the server side.\n    observation = {\n        \"observation/image\": image_tools.convert_to_uint8(\n            image_tools.resize_with_pad(img, 224, 224)\n        ),\n        \"observation/wrist_image\": image_tools.convert_to_uint8(\n            image_tools.resize_with_pad(wrist_img, 224, 224)\n        ),\n        \"observation/state\": state,\n        \"prompt\": task_instruction,\n    }\n\n    # Call the policy server with the current observation.\n    # This returns an action chunk of shape (action_horizon, action_dim).\n    # Note that you typically only need to call the policy every N steps and execute steps\n    # from the predicted action chunk open-loop in the remaining steps.\n    action_chunk = client.infer(observation)[\"actions\"]\n\n    # Execute the actions in the environment.\n    ...\n\n```\n\nHere, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. 
The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](../examples/simple_client/main.py).\n"
  },
  {
    "path": "examples/aloha_real/Dockerfile",
    "content": "# Dockerfile for the Aloha real environment.\n\n# Build the container:\n# docker build . -t aloha_real -f examples/aloha_real/Dockerfile\n\n# Run the container:\n# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash\n\nFROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc\nSHELL [\"/bin/bash\", \"-c\"]\n\nENV DEBIAN_FRONTEND=noninteractive\nRUN apt-get update && \\\n    apt-get install -y --no-install-recommends \\\n    cmake \\\n    curl \\\n    libffi-dev \\\n    python3-rosdep \\\n    python3-rosinstall \\\n    python3-rosinstall-generator \\\n    whiptail \\\n    git \\\n    wget \\\n    openssh-client \\\n    ros-noetic-cv-bridge \\\n    ros-noetic-usb-cam \\\n    ros-noetic-realsense2-camera \\\n    keyboard-configuration\n\nWORKDIR /root\nRUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh\nRUN chmod +x xsarm_amd64_install.sh\nRUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n\n\nCOPY ./third_party/aloha /root/interbotix_ws/src/aloha\nRUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make\n\n# Install python 3.10 because this ROS image comes with 3.8\nRUN mkdir /python && \\\n    cd /python && \\\n    wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \\\n    tar -zxvf Python-3.10.14.tgz && \\\n    cd Python-3.10.14 && \\\n    ls -lhR && \\\n    ./configure --enable-optimizations && \\\n    make install && \\\n    echo 'alias python3=\"/usr/local/bin/python3.10\"' >> ~/.bashrc && \\\n    echo 'alias python=\"/usr/local/bin/python3.10\"' >> ~/.bashrc && \\\n    cd ~ && rm -rf /python && \\\n    rm -rf /var/lib/apt/lists/*\n\nCOPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv\nENV UV_HTTP_TIMEOUT=120\nENV UV_LINK_MODE=copy\nCOPY 
./examples/aloha_real/requirements.txt /tmp/requirements.txt\nCOPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml\nRUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml\n\nENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha\nWORKDIR /app\n\n# Create an entrypoint script to run the setup commands, followed by the command passed in.\nRUN cat <<'EOF' > /usr/local/bin/entrypoint.sh\n#!/bin/bash\nsource /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && \"$@\"\nEOF\nRUN chmod +x /usr/local/bin/entrypoint.sh\n\nENTRYPOINT [\"/usr/local/bin/entrypoint.sh\"]\nCMD [\"python3\", \"/app/examples/aloha_real/main.py\"]\n"
  },
  {
    "path": "examples/aloha_real/README.md",
    "content": "# Run Aloha (Real Robot)\n\nThis example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.\n\n## Prerequisites\n\nThis repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.\n\n1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.\n1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.\n\n## With Docker\n\n```bash\nexport SERVER_ARGS=\"--env ALOHA --default_prompt='take the toast out of the toaster'\"\ndocker compose -f examples/aloha_real/compose.yml up --build\n```\n\n## Without Docker\n\nTerminal window 1:\n\n```bash\n# Create virtual environment\nuv venv --python 3.10 examples/aloha_real/.venv\nsource examples/aloha_real/.venv/bin/activate\nuv pip sync examples/aloha_real/requirements.txt\nuv pip install -e packages/openpi-client\n\n# Run the robot\npython -m examples.aloha_real.main\n```\n\nTerminal window 2:\n\n```bash\nroslaunch aloha ros_nodes.launch\n```\n\nTerminal window 3:\n\n```bash\nuv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'\n```\n\n## **ALOHA Checkpoint Guide**\n\n\nThe `pi0_base` model can be used in zero shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA.\n\nWhile we’ve found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. 
We cover the prompts to use for the policies, objects we’ve seen them work well on, and well-represented initial state distributions. Running these policies zero-shot is still highly experimental, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by fine-tuning with data from the target robot.\n\n\n---\n\n### **Toast Task**\n\nThis task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.\n\n- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_base`\n- **Prompt**: \"take the toast out of the toaster\"\n- **Objects needed**: Two pieces of toast, a plate, and a standard toaster.\n- **Object Distribution**:\n  - Works on both real toast and rubber fake toast\n  - Compatible with standard 2-slice toasters\n  - Works with plates of varying colors\n\n### **Scene Setup Guidelines**\n<img width=\"500\" alt=\"Screenshot 2025-01-31 at 10 06 02 PM\" src=\"https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e\" />\n\n- The toaster should be positioned in the top-left quadrant of the workspace.\n- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.\n- The plate should be placed roughly in the lower-center of the workspace.\n- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).\n\n\n### **Towel Task**\n\nThis task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.\n\n- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_towel`\n- **Prompt**: \"fold the towel\"\n- **Object Distribution**:\n  - Works on towels of varying solid colors\n  - Performance is worse on heavily textured or striped towels\n\n### **Scene Setup Guidelines**\n<img width=\"500\" alt=\"Screenshot 2025-01-31 at 10 01 15 PM\" 
src=\"https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943\" />\n\n- The towel should be flattened and roughly centered on the table.\n- Choose a towel that does not blend in with the table surface.\n\n\n### **Tupperware Task**\n\nThis task involves opening a tupperware filled with food and pouring the contents onto a plate.\n\n- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_tupperware`\n- **Prompt**: \"open the tupperware and put the food on the plate\"\n- **Objects needed**: Tupperware, food (or food-like items), and a plate.\n- **Object Distribution**:\n  - Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).\n  - Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).\n  - The policy has seen plates of varying solid colors.\n\n### **Scene Setup Guidelines**\n<img width=\"500\" alt=\"Screenshot 2025-01-31 at 10 02 27 PM\" src=\"https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf\" />\n\n- Best performance observed when both the tupperware and plate are roughly centered in the workspace.\n- Positioning:\n  - Tupperware should be on the left.\n  - Plate should be on the right or bottom.\n  - The tupperware flap should point toward the plate.\n\n## Training on your own Aloha dataset\n\n1. Convert the dataset to the LeRobot dataset v2.0 format.\n\n    We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).\n\n\n2. 
Define a training config that uses the custom dataset.\n\n    We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. Refer to the root [README](../../README.md) for instructions on running training with the new config.\n\nIMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the `trossen` `asset_id` and a path to the pretrained checkpoint’s asset directory within the `AssetsConfig`.\n"
  },
  {
    "path": "examples/aloha_real/compose.yml",
    "content": "# Run with:\n# docker compose -f examples/aloha_real/compose.yml up --build\nservices:\n  runtime:\n    image: aloha_real\n    depends_on:\n      - aloha_ros_nodes\n      - ros_master\n      - openpi_server\n    build:\n      context: ../..\n      dockerfile: examples/aloha_real/Dockerfile\n    init: true\n    tty: true\n    network_mode: host\n    privileged: true\n    volumes:\n      - $PWD:/app\n      - ../../data:/data\n\n  aloha_ros_nodes:\n    image: aloha_real\n    depends_on:\n      - ros_master\n    build:\n      context: ../..\n      dockerfile: examples/aloha_real/Dockerfile\n    init: true\n    tty: true\n    network_mode: host\n    privileged: true\n    volumes:\n      - /dev:/dev\n    command: roslaunch --wait aloha ros_nodes.launch\n\n  ros_master:\n    image: ros:noetic-robot\n    network_mode: host\n    privileged: true\n    command:\n      - roscore\n\n  openpi_server:\n    image: openpi_server\n    build:\n      context: ../..\n      dockerfile: scripts/docker/serve_policy.Dockerfile\n    init: true\n    tty: true\n    network_mode: host\n    volumes:\n      - $PWD:/app\n      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets\n    environment:\n      - SERVER_ARGS\n      - OPENPI_DATA_HOME=/openpi_assets\n      - IS_DOCKER=true\n\n    # Comment out this block if not running on a machine with GPUs.\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: 1\n              capabilities: [gpu]\n"
  },
  {
    "path": "examples/aloha_real/constants.py",
    "content": "# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).\n# ruff: noqa\n\n### Task parameters\n\n### ALOHA fixed constants\nDT = 0.001\nJOINT_NAMES = [\"waist\", \"shoulder\", \"elbow\", \"forearm_roll\", \"wrist_angle\", \"wrist_rotate\"]\nSTART_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]\n\n# Left finger position limits (qpos[7]), right_finger = -1 * left_finger\nMASTER_GRIPPER_POSITION_OPEN = 0.02417\nMASTER_GRIPPER_POSITION_CLOSE = 0.01244\nPUPPET_GRIPPER_POSITION_OPEN = 0.05800\nPUPPET_GRIPPER_POSITION_CLOSE = 0.01844\n\n# Gripper joint limits (qpos[6])\nMASTER_GRIPPER_JOINT_OPEN = 0.3083\nMASTER_GRIPPER_JOINT_CLOSE = -0.6842\nPUPPET_GRIPPER_JOINT_OPEN = 1.4910\nPUPPET_GRIPPER_JOINT_CLOSE = -0.6213\n\n############################ Helper functions ############################\n\nMASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (\n    MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE\n)\nPUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (\n    PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE\n)\nMASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (\n    lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE\n)\nPUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (\n    lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE\n)\nMASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))\n\nMASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (\n    MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE\n)\nPUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (\n    PUPPET_GRIPPER_JOINT_OPEN - 
PUPPET_GRIPPER_JOINT_CLOSE\n)\nMASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (\n    lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE\n)\nPUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (\n    lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE\n)\nMASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))\n\nMASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)\nPUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)\n\nMASTER_POS2JOINT = (\n    lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)\n    + MASTER_GRIPPER_JOINT_CLOSE\n)\nMASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(\n    (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)\n)\nPUPPET_POS2JOINT = (\n    lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)\n    + PUPPET_GRIPPER_JOINT_CLOSE\n)\nPUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(\n    (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)\n)\n\nMASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2\n"
  },
  {
    "path": "examples/aloha_real/convert_aloha_data_to_lerobot.py",
    "content": "\"\"\"\nScript to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.\n\nExample usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>\n\"\"\"\n\nimport dataclasses\nfrom pathlib import Path\nimport shutil\nfrom typing import Literal\n\nimport h5py\nfrom lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME\nfrom lerobot.common.datasets.lerobot_dataset import LeRobotDataset\nfrom lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw\nimport numpy as np\nimport torch\nimport tqdm\nimport tyro\n\n\n@dataclasses.dataclass(frozen=True)\nclass DatasetConfig:\n    use_videos: bool = True\n    tolerance_s: float = 0.0001\n    image_writer_processes: int = 10\n    image_writer_threads: int = 5\n    video_backend: str | None = None\n\n\nDEFAULT_DATASET_CONFIG = DatasetConfig()\n\n\ndef create_empty_dataset(\n    repo_id: str,\n    robot_type: str,\n    mode: Literal[\"video\", \"image\"] = \"video\",\n    *,\n    has_velocity: bool = False,\n    has_effort: bool = False,\n    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,\n) -> LeRobotDataset:\n    motors = [\n        \"right_waist\",\n        \"right_shoulder\",\n        \"right_elbow\",\n        \"right_forearm_roll\",\n        \"right_wrist_angle\",\n        \"right_wrist_rotate\",\n        \"right_gripper\",\n        \"left_waist\",\n        \"left_shoulder\",\n        \"left_elbow\",\n        \"left_forearm_roll\",\n        \"left_wrist_angle\",\n        \"left_wrist_rotate\",\n        \"left_gripper\",\n    ]\n    cameras = [\n        \"cam_high\",\n        \"cam_low\",\n        \"cam_left_wrist\",\n        \"cam_right_wrist\",\n    ]\n\n    features = {\n        \"observation.state\": {\n            \"dtype\": \"float32\",\n            \"shape\": (len(motors),),\n            \"names\": [\n                motors,\n            ],\n        },\n        \"action\": {\n            
\"dtype\": \"float32\",\n            \"shape\": (len(motors),),\n            \"names\": [\n                motors,\n            ],\n        },\n    }\n\n    if has_velocity:\n        features[\"observation.velocity\"] = {\n            \"dtype\": \"float32\",\n            \"shape\": (len(motors),),\n            \"names\": [\n                motors,\n            ],\n        }\n\n    if has_effort:\n        features[\"observation.effort\"] = {\n            \"dtype\": \"float32\",\n            \"shape\": (len(motors),),\n            \"names\": [\n                motors,\n            ],\n        }\n\n    for cam in cameras:\n        features[f\"observation.images.{cam}\"] = {\n            \"dtype\": mode,\n            \"shape\": (3, 480, 640),\n            \"names\": [\n                \"channels\",\n                \"height\",\n                \"width\",\n            ],\n        }\n\n    if Path(LEROBOT_HOME / repo_id).exists():\n        shutil.rmtree(LEROBOT_HOME / repo_id)\n\n    return LeRobotDataset.create(\n        repo_id=repo_id,\n        fps=50,\n        robot_type=robot_type,\n        features=features,\n        use_videos=dataset_config.use_videos,\n        tolerance_s=dataset_config.tolerance_s,\n        image_writer_processes=dataset_config.image_writer_processes,\n        image_writer_threads=dataset_config.image_writer_threads,\n        video_backend=dataset_config.video_backend,\n    )\n\n\ndef get_cameras(hdf5_files: list[Path]) -> list[str]:\n    with h5py.File(hdf5_files[0], \"r\") as ep:\n        # ignore depth channel, not currently handled\n        return [key for key in ep[\"/observations/images\"].keys() if \"depth\" not in key]  # noqa: SIM118\n\n\ndef has_velocity(hdf5_files: list[Path]) -> bool:\n    with h5py.File(hdf5_files[0], \"r\") as ep:\n        return \"/observations/qvel\" in ep\n\n\ndef has_effort(hdf5_files: list[Path]) -> bool:\n    with h5py.File(hdf5_files[0], \"r\") as ep:\n        return \"/observations/effort\" in ep\n\n\ndef 
load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:\n    imgs_per_cam = {}\n    for camera in cameras:\n        uncompressed = ep[f\"/observations/images/{camera}\"].ndim == 4\n\n        if uncompressed:\n            # load all images in RAM\n            imgs_array = ep[f\"/observations/images/{camera}\"][:]\n        else:\n            import cv2\n\n            # load one compressed image after the other in RAM and uncompress\n            imgs_array = []\n            for data in ep[f\"/observations/images/{camera}\"]:\n                imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))\n            imgs_array = np.array(imgs_array)\n\n        imgs_per_cam[camera] = imgs_array\n    return imgs_per_cam\n\n\ndef load_raw_episode_data(\n    ep_path: Path,\n) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:\n    with h5py.File(ep_path, \"r\") as ep:\n        state = torch.from_numpy(ep[\"/observations/qpos\"][:])\n        action = torch.from_numpy(ep[\"/action\"][:])\n\n        velocity = None\n        if \"/observations/qvel\" in ep:\n            velocity = torch.from_numpy(ep[\"/observations/qvel\"][:])\n\n        effort = None\n        if \"/observations/effort\" in ep:\n            effort = torch.from_numpy(ep[\"/observations/effort\"][:])\n\n        imgs_per_cam = load_raw_images_per_camera(\n            ep,\n            [\n                \"cam_high\",\n                \"cam_low\",\n                \"cam_left_wrist\",\n                \"cam_right_wrist\",\n            ],\n        )\n\n    return imgs_per_cam, state, action, velocity, effort\n\n\ndef populate_dataset(\n    dataset: LeRobotDataset,\n    hdf5_files: list[Path],\n    task: str,\n    episodes: list[int] | None = None,\n) -> LeRobotDataset:\n    if episodes is None:\n        episodes = range(len(hdf5_files))\n\n    for ep_idx in tqdm.tqdm(episodes):\n        ep_path = hdf5_files[ep_idx]\n\n    
    imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)\n        num_frames = state.shape[0]\n\n        for i in range(num_frames):\n            frame = {\n                \"observation.state\": state[i],\n                \"action\": action[i],\n            }\n\n            for camera, img_array in imgs_per_cam.items():\n                frame[f\"observation.images.{camera}\"] = img_array[i]\n\n            if velocity is not None:\n                frame[\"observation.velocity\"] = velocity[i]\n            if effort is not None:\n                frame[\"observation.effort\"] = effort[i]\n\n            dataset.add_frame(frame)\n\n        dataset.save_episode(task=task)\n\n    return dataset\n\n\ndef port_aloha(\n    raw_dir: Path,\n    repo_id: str,\n    raw_repo_id: str | None = None,\n    task: str = \"DEBUG\",\n    *,\n    episodes: list[int] | None = None,\n    push_to_hub: bool = True,\n    is_mobile: bool = False,\n    mode: Literal[\"video\", \"image\"] = \"image\",\n    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,\n):\n    if (LEROBOT_HOME / repo_id).exists():\n        shutil.rmtree(LEROBOT_HOME / repo_id)\n\n    if not raw_dir.exists():\n        if raw_repo_id is None:\n            raise ValueError(\"raw_repo_id must be provided if raw_dir does not exist\")\n        download_raw(raw_dir, repo_id=raw_repo_id)\n\n    hdf5_files = sorted(raw_dir.glob(\"episode_*.hdf5\"))\n\n    dataset = create_empty_dataset(\n        repo_id,\n        robot_type=\"mobile_aloha\" if is_mobile else \"aloha\",\n        mode=mode,\n        has_effort=has_effort(hdf5_files),\n        has_velocity=has_velocity(hdf5_files),\n        dataset_config=dataset_config,\n    )\n    dataset = populate_dataset(\n        dataset,\n        hdf5_files,\n        task=task,\n        episodes=episodes,\n    )\n    dataset.consolidate()\n\n    if push_to_hub:\n        dataset.push_to_hub()\n\n\nif __name__ == \"__main__\":\n    tyro.cli(port_aloha)\n"
  },
  {
    "path": "examples/aloha_real/env.py",
    "content": "from typing import List, Optional  # noqa: UP035\n\nimport einops\nfrom openpi_client import image_tools\nfrom openpi_client.runtime import environment as _environment\nfrom typing_extensions import override\n\nfrom examples.aloha_real import real_env as _real_env\n\n\nclass AlohaRealEnvironment(_environment.Environment):\n    \"\"\"An environment for an Aloha robot on real hardware.\"\"\"\n\n    def __init__(\n        self,\n        reset_position: Optional[List[float]] = None,  # noqa: UP006,UP007\n        render_height: int = 224,\n        render_width: int = 224,\n    ) -> None:\n        self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)\n        self._render_height = render_height\n        self._render_width = render_width\n\n        self._ts = None\n\n    @override\n    def reset(self) -> None:\n        self._ts = self._env.reset()\n\n    @override\n    def is_episode_complete(self) -> bool:\n        return False\n\n    @override\n    def get_observation(self) -> dict:\n        if self._ts is None:\n            raise RuntimeError(\"Timestep is not set. Call reset() first.\")\n\n        obs = self._ts.observation\n        for k in list(obs[\"images\"].keys()):\n            if \"_depth\" in k:\n                del obs[\"images\"][k]\n\n        for cam_name in obs[\"images\"]:\n            img = image_tools.convert_to_uint8(\n                image_tools.resize_with_pad(obs[\"images\"][cam_name], self._render_height, self._render_width)\n            )\n            obs[\"images\"][cam_name] = einops.rearrange(img, \"h w c -> c h w\")\n\n        return {\n            \"state\": obs[\"qpos\"],\n            \"images\": obs[\"images\"],\n        }\n\n    @override\n    def apply_action(self, action: dict) -> None:\n        self._ts = self._env.step(action[\"actions\"])\n"
  },
  {
    "path": "examples/aloha_real/main.py",
    "content": "import dataclasses\nimport logging\n\nfrom openpi_client import action_chunk_broker\nfrom openpi_client import websocket_client_policy as _websocket_client_policy\nfrom openpi_client.runtime import runtime as _runtime\nfrom openpi_client.runtime.agents import policy_agent as _policy_agent\nimport tyro\n\nfrom examples.aloha_real import env as _env\n\n\n@dataclasses.dataclass\nclass Args:\n    host: str = \"0.0.0.0\"\n    port: int = 8000\n\n    action_horizon: int = 25\n\n    num_episodes: int = 1\n    max_episode_steps: int = 1000\n\n\ndef main(args: Args) -> None:\n    ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(\n        host=args.host,\n        port=args.port,\n    )\n    logging.info(f\"Server metadata: {ws_client_policy.get_server_metadata()}\")\n\n    metadata = ws_client_policy.get_server_metadata()\n    runtime = _runtime.Runtime(\n        environment=_env.AlohaRealEnvironment(reset_position=metadata.get(\"reset_pose\")),\n        agent=_policy_agent.PolicyAgent(\n            policy=action_chunk_broker.ActionChunkBroker(\n                policy=ws_client_policy,\n                action_horizon=args.action_horizon,\n            )\n        ),\n        subscribers=[],\n        max_hz=50,\n        num_episodes=args.num_episodes,\n        max_episode_steps=args.max_episode_steps,\n    )\n\n    runtime.run()\n\n\nif __name__ == \"__main__\":\n    logging.basicConfig(level=logging.INFO, force=True)\n    tyro.cli(main)\n"
  },
  {
    "path": "examples/aloha_real/real_env.py",
    "content": "# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).\n# ruff: noqa\nimport collections\nimport time\nfrom typing import Optional, List\nimport dm_env\nfrom interbotix_xs_modules.arm import InterbotixManipulatorXS\nfrom interbotix_xs_msgs.msg import JointSingleCommand\nimport numpy as np\n\nfrom examples.aloha_real import constants\nfrom examples.aloha_real import robot_utils\n\n# This is the reset position that is used by the standard Aloha runtime.\nDEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]\n\n\nclass RealEnv:\n    \"\"\"\n    Environment for real robot bi-manual manipulation\n    Action space:      [left_arm_qpos (6),             # absolute joint position\n                        left_gripper_positions (1),    # normalized gripper position (0: close, 1: open)\n                        right_arm_qpos (6),            # absolute joint position\n                        right_gripper_positions (1),]  # normalized gripper position (0: close, 1: open)\n\n    Observation space: {\"qpos\": Concat[ left_arm_qpos (6),          # absolute joint position\n                                        left_gripper_position (1),  # normalized gripper position (0: close, 1: open)\n                                        right_arm_qpos (6),         # absolute joint position\n                                        right_gripper_qpos (1)]     # normalized gripper position (0: close, 1: open)\n                        \"qvel\": Concat[ left_arm_qvel (6),         # absolute joint velocity (rad)\n                                        left_gripper_velocity (1),  # normalized gripper velocity (pos: opening, neg: closing)\n                                        right_arm_qvel (6),         # absolute joint velocity (rad)\n                                        right_gripper_qvel (1)]     # normalized gripper velocity (pos: opening, neg: closing)\n                        \"images\": {\"cam_high\": (480x640x3),        
# h, w, c, dtype='uint8'\n                                   \"cam_low\": (480x640x3),         # h, w, c, dtype='uint8'\n                                   \"cam_left_wrist\": (480x640x3),  # h, w, c, dtype='uint8'\n                                   \"cam_right_wrist\": (480x640x3)} # h, w, c, dtype='uint8'\n    \"\"\"\n\n    def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):\n        # reset_position = START_ARM_POSE[:6]\n        self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION\n\n        self.puppet_bot_left = InterbotixManipulatorXS(\n            robot_model=\"vx300s\",\n            group_name=\"arm\",\n            gripper_name=\"gripper\",\n            robot_name=\"puppet_left\",\n            init_node=init_node,\n        )\n        self.puppet_bot_right = InterbotixManipulatorXS(\n            robot_model=\"vx300s\", group_name=\"arm\", gripper_name=\"gripper\", robot_name=\"puppet_right\", init_node=False\n        )\n        if setup_robots:\n            self.setup_robots()\n\n        self.recorder_left = robot_utils.Recorder(\"left\", init_node=False)\n        self.recorder_right = robot_utils.Recorder(\"right\", init_node=False)\n        self.image_recorder = robot_utils.ImageRecorder(init_node=False)\n        self.gripper_command = JointSingleCommand(name=\"gripper\")\n\n    def setup_robots(self):\n        robot_utils.setup_puppet_bot(self.puppet_bot_left)\n        robot_utils.setup_puppet_bot(self.puppet_bot_right)\n\n    def get_qpos(self):\n        left_qpos_raw = self.recorder_left.qpos\n        right_qpos_raw = self.recorder_right.qpos\n        left_arm_qpos = left_qpos_raw[:6]\n        right_arm_qpos = right_qpos_raw[:6]\n        left_gripper_qpos = [\n            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])\n        ]  # this is position not joint\n        right_gripper_qpos = [\n            
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])\n        ]  # this is position not joint\n        return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])\n\n    def get_qvel(self):\n        left_qvel_raw = self.recorder_left.qvel\n        right_qvel_raw = self.recorder_right.qvel\n        left_arm_qvel = left_qvel_raw[:6]\n        right_arm_qvel = right_qvel_raw[:6]\n        left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]\n        right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]\n        return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])\n\n    def get_effort(self):\n        left_effort_raw = self.recorder_left.effort\n        right_effort_raw = self.recorder_right.effort\n        left_robot_effort = left_effort_raw[:7]\n        right_robot_effort = right_effort_raw[:7]\n        return np.concatenate([left_robot_effort, right_robot_effort])\n\n    def get_images(self):\n        return self.image_recorder.get_images()\n\n    def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):\n        left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)\n        self.gripper_command.cmd = left_gripper_desired_joint\n        self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)\n\n        right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(\n            right_gripper_desired_pos_normalized\n        )\n        self.gripper_command.cmd = right_gripper_desired_joint\n        self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)\n\n    def _reset_joints(self):\n        robot_utils.move_arms(\n            [self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1\n        )\n\n    def 
_reset_gripper(self):\n        \"\"\"Set to position mode and do position resets: first close then open. Then change back to PWM mode\n\n        NOTE: This diverges from the original Aloha code which first opens then closes the gripper. Pi internal aloha data\n        was collected with the gripper starting in the open position. Leaving the grippers fully closed was also found to\n        increase the frequency of motor faults.\n        \"\"\"\n        robot_utils.move_grippers(\n            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1\n        )\n        robot_utils.move_grippers(\n            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5\n        )\n\n    def get_observation(self):\n        obs = collections.OrderedDict()\n        obs[\"qpos\"] = self.get_qpos()\n        obs[\"qvel\"] = self.get_qvel()\n        obs[\"effort\"] = self.get_effort()\n        obs[\"images\"] = self.get_images()\n        return obs\n\n    def get_reward(self):\n        return 0\n\n    def reset(self, *, fake=False):\n        if not fake:\n            # Reboot puppet robot gripper motors\n            self.puppet_bot_left.dxl.robot_reboot_motors(\"single\", \"gripper\", True)\n            self.puppet_bot_right.dxl.robot_reboot_motors(\"single\", \"gripper\", True)\n            self._reset_joints()\n            self._reset_gripper()\n        return dm_env.TimeStep(\n            step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()\n        )\n\n    def step(self, action):\n        state_len = int(len(action) / 2)\n        left_action = action[:state_len]\n        right_action = action[state_len:]\n        self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)\n        self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)\n        self.set_gripper_pose(left_action[-1], 
right_action[-1])\n        time.sleep(constants.DT)\n        return dm_env.TimeStep(\n            step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()\n        )\n\n\ndef get_action(master_bot_left, master_bot_right):\n    action = np.zeros(14)  # 6 joint + 1 gripper, for two arms\n    # Arm actions\n    action[:6] = master_bot_left.dxl.joint_states.position[:6]\n    action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]\n    # Gripper actions\n    action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])\n    action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])\n\n    return action\n\n\ndef make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:\n    return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)\n"
  },
  {
    "path": "examples/aloha_real/requirements.in",
    "content": "Pillow\ndm_control\neinops\nh5py\nmatplotlib\nmodern_robotics\nmsgpack\nnumpy>=1.22.4,<2.0.0\nopencv-python\npackaging\npexpect\npyquaternion\npyrealsense2\npyyaml\nrequests\nrospkg\ntyro\nwebsockets\n"
  },
  {
    "path": "examples/aloha_real/requirements.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10\nabsl-py==2.1.0\n    # via\n    #   dm-control\n    #   dm-env\n    #   labmaze\n    #   mujoco\ncatkin-pkg==1.0.0\n    # via rospkg\ncertifi==2024.8.30\n    # via requests\ncharset-normalizer==3.4.0\n    # via requests\ncontourpy==1.1.1\n    # via matplotlib\ncycler==0.12.1\n    # via matplotlib\ndistro==1.9.0\n    # via rospkg\ndm-control==1.0.23\n    # via -r examples/aloha_real/requirements.in\ndm-env==1.6\n    # via dm-control\ndm-tree==0.1.8\n    # via\n    #   dm-control\n    #   dm-env\ndocstring-parser==0.16\n    # via tyro\ndocutils==0.20.1\n    # via catkin-pkg\neinops==0.8.0\n    # via -r examples/aloha_real/requirements.in\netils==1.3.0\n    # via mujoco\nfonttools==4.55.2\n    # via matplotlib\nglfw==2.8.0\n    # via\n    #   dm-control\n    #   mujoco\nh5py==3.11.0\n    # via -r examples/aloha_real/requirements.in\nidna==3.10\n    # via requests\nimportlib-resources==6.4.5\n    # via etils\nkiwisolver==1.4.7\n    # via matplotlib\nlabmaze==1.0.6\n    # via dm-control\nlxml==5.3.0\n    # via dm-control\nmarkdown-it-py==3.0.0\n    # via rich\nmatplotlib==3.7.5\n    # via -r examples/aloha_real/requirements.in\nmdurl==0.1.2\n    # via markdown-it-py\nmodern-robotics==1.1.1\n    # via -r examples/aloha_real/requirements.in\nmsgpack==1.1.0\n    # via -r examples/aloha_real/requirements.in\nmujoco==3.2.3\n    # via dm-control\nnumpy==1.24.4\n    # via\n    #   -r examples/aloha_real/requirements.in\n    #   contourpy\n    #   dm-control\n    #   dm-env\n    #   h5py\n    #   labmaze\n    #   matplotlib\n    #   modern-robotics\n    #   mujoco\n    #   opencv-python\n    #   pyquaternion\n    #   scipy\nopencv-python==4.10.0.84\n    # via -r examples/aloha_real/requirements.in\npackaging==24.2\n    # via\n    #   -r examples/aloha_real/requirements.in\n    #   
matplotlib\npexpect==4.9.0\n    # via -r examples/aloha_real/requirements.in\npillow==10.4.0\n    # via\n    #   -r examples/aloha_real/requirements.in\n    #   matplotlib\nprotobuf==5.29.1\n    # via dm-control\nptyprocess==0.7.0\n    # via pexpect\npygments==2.18.0\n    # via rich\npyopengl==3.1.7\n    # via\n    #   dm-control\n    #   mujoco\npyparsing==3.1.4\n    # via\n    #   catkin-pkg\n    #   dm-control\n    #   matplotlib\npyquaternion==0.9.9\n    # via -r examples/aloha_real/requirements.in\npyrealsense2==2.55.1.6486\n    # via -r examples/aloha_real/requirements.in\npython-dateutil==2.9.0.post0\n    # via\n    #   catkin-pkg\n    #   matplotlib\npyyaml==6.0.2\n    # via\n    #   -r examples/aloha_real/requirements.in\n    #   rospkg\nrequests==2.32.3\n    # via\n    #   -r examples/aloha_real/requirements.in\n    #   dm-control\nrich==13.9.4\n    # via tyro\nrospkg==1.5.1\n    # via -r examples/aloha_real/requirements.in\nscipy==1.10.1\n    # via dm-control\nsetuptools==75.3.0\n    # via\n    #   catkin-pkg\n    #   dm-control\n    #   labmaze\nshtab==1.7.1\n    # via tyro\nsix==1.17.0\n    # via python-dateutil\ntqdm==4.67.1\n    # via dm-control\ntypeguard==4.4.0\n    # via tyro\ntyping-extensions==4.12.2\n    # via\n    #   etils\n    #   rich\n    #   typeguard\n    #   tyro\ntyro==0.9.2\n    # via -r examples/aloha_real/requirements.in\nurllib3==2.2.3\n    # via requests\nwebsockets==14.1\n    # via -r examples/aloha_real/requirements.in\nzipp==3.20.2\n    # via etils\n"
  },
  {
    "path": "examples/aloha_real/robot_utils.py",
    "content": "# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).\n# ruff: noqa\nfrom collections import deque\nimport datetime\nimport json\nimport time\n\nfrom aloha.msg import RGBGrayscaleImage\nfrom cv_bridge import CvBridge\nfrom interbotix_xs_msgs.msg import JointGroupCommand\nfrom interbotix_xs_msgs.msg import JointSingleCommand\nimport numpy as np\nimport rospy\nfrom sensor_msgs.msg import JointState\n\nfrom examples.aloha_real import constants\n\n\nclass ImageRecorder:\n    def __init__(self, init_node=True, is_debug=False):\n        self.is_debug = is_debug\n        self.bridge = CvBridge()\n        self.camera_names = [\"cam_high\", \"cam_low\", \"cam_left_wrist\", \"cam_right_wrist\"]\n\n        if init_node:\n            rospy.init_node(\"image_recorder\", anonymous=True)\n        for cam_name in self.camera_names:\n            setattr(self, f\"{cam_name}_rgb_image\", None)\n            setattr(self, f\"{cam_name}_depth_image\", None)\n            setattr(self, f\"{cam_name}_timestamp\", 0.0)\n            if cam_name == \"cam_high\":\n                callback_func = self.image_cb_cam_high\n            elif cam_name == \"cam_low\":\n                callback_func = self.image_cb_cam_low\n            elif cam_name == \"cam_left_wrist\":\n                callback_func = self.image_cb_cam_left_wrist\n            elif cam_name == \"cam_right_wrist\":\n                callback_func = self.image_cb_cam_right_wrist\n            else:\n                raise NotImplementedError\n            rospy.Subscriber(f\"/{cam_name}\", RGBGrayscaleImage, callback_func)\n            if self.is_debug:\n                setattr(self, f\"{cam_name}_timestamps\", deque(maxlen=50))\n\n        self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}\n        time.sleep(0.5)\n\n    def image_cb(self, cam_name, data):\n        setattr(\n            self,\n            f\"{cam_name}_rgb_image\",\n            
self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding=\"bgr8\"),\n        )\n        # setattr(\n        #     self,\n        #     f\"{cam_name}_depth_image\",\n        #     self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding=\"mono16\"),\n        # )\n        setattr(\n            self,\n            f\"{cam_name}_timestamp\",\n            data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,\n        )\n        # setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)\n        # setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)\n        # cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)\n        if self.is_debug:\n            getattr(self, f\"{cam_name}_timestamps\").append(\n                data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9\n            )\n\n    def image_cb_cam_high(self, data):\n        cam_name = \"cam_high\"\n        return self.image_cb(cam_name, data)\n\n    def image_cb_cam_low(self, data):\n        cam_name = \"cam_low\"\n        return self.image_cb(cam_name, data)\n\n    def image_cb_cam_left_wrist(self, data):\n        cam_name = \"cam_left_wrist\"\n        return self.image_cb(cam_name, data)\n\n    def image_cb_cam_right_wrist(self, data):\n        cam_name = \"cam_right_wrist\"\n        return self.image_cb(cam_name, data)\n\n    def get_images(self):\n        image_dict = {}\n        for cam_name in self.camera_names:\n            while getattr(self, f\"{cam_name}_timestamp\") <= self.cam_last_timestamps[cam_name]:\n                time.sleep(0.00001)\n            rgb_image = getattr(self, f\"{cam_name}_rgb_image\")\n            depth_image = getattr(self, f\"{cam_name}_depth_image\")\n            self.cam_last_timestamps[cam_name] = getattr(self, f\"{cam_name}_timestamp\")\n            image_dict[cam_name] = rgb_image\n            image_dict[f\"{cam_name}_depth\"] = depth_image\n        return image_dict\n\n    def print_diagnostics(self):\n        
def dt_helper(l):\n            l = np.array(l)\n            diff = l[1:] - l[:-1]\n            return np.mean(diff)\n\n        for cam_name in self.camera_names:\n            image_freq = 1 / dt_helper(getattr(self, f\"{cam_name}_timestamps\"))\n            print(f\"{cam_name} {image_freq=:.2f}\")\n        print()\n\n\nclass Recorder:\n    def __init__(self, side, init_node=True, is_debug=False):\n        self.secs = None\n        self.nsecs = None\n        self.qpos = None\n        self.effort = None\n        self.arm_command = None\n        self.gripper_command = None\n        self.is_debug = is_debug\n\n        if init_node:\n            rospy.init_node(\"recorder\", anonymous=True)\n        rospy.Subscriber(f\"/puppet_{side}/joint_states\", JointState, self.puppet_state_cb)\n        rospy.Subscriber(\n            f\"/puppet_{side}/commands/joint_group\",\n            JointGroupCommand,\n            self.puppet_arm_commands_cb,\n        )\n        rospy.Subscriber(\n            f\"/puppet_{side}/commands/joint_single\",\n            JointSingleCommand,\n            self.puppet_gripper_commands_cb,\n        )\n        if self.is_debug:\n            self.joint_timestamps = deque(maxlen=50)\n            self.arm_command_timestamps = deque(maxlen=50)\n            self.gripper_command_timestamps = deque(maxlen=50)\n        time.sleep(0.1)\n\n    def puppet_state_cb(self, data):\n        self.qpos = data.position\n        self.qvel = data.velocity\n        self.effort = data.effort\n        self.data = data\n        if self.is_debug:\n            self.joint_timestamps.append(time.time())\n\n    def puppet_arm_commands_cb(self, data):\n        self.arm_command = data.cmd\n        if self.is_debug:\n            self.arm_command_timestamps.append(time.time())\n\n    def puppet_gripper_commands_cb(self, data):\n        self.gripper_command = data.cmd\n        if self.is_debug:\n            self.gripper_command_timestamps.append(time.time())\n\n    def 
print_diagnostics(self):\n        def dt_helper(l):\n            l = np.array(l)\n            diff = l[1:] - l[:-1]\n            return np.mean(diff)\n\n        joint_freq = 1 / dt_helper(self.joint_timestamps)\n        arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)\n        gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)\n\n        print(f\"{joint_freq=:.2f}\\n{arm_command_freq=:.2f}\\n{gripper_command_freq=:.2f}\\n\")\n\n\ndef get_arm_joint_positions(bot):\n    return bot.arm.core.joint_states.position[:6]\n\n\ndef get_arm_gripper_positions(bot):\n    return bot.gripper.core.joint_states.position[6]\n\n\ndef move_arms(bot_list, target_pose_list, move_time=1):\n    num_steps = int(move_time / constants.DT)\n    curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]\n    traj_list = [\n        np.linspace(curr_pose, target_pose, num_steps)\n        for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)\n    ]\n    for t in range(num_steps):\n        for bot_id, bot in enumerate(bot_list):\n            bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)\n        time.sleep(constants.DT)\n\n\ndef move_grippers(bot_list, target_pose_list, move_time):\n    print(f\"Moving grippers to {target_pose_list=}\")\n    gripper_command = JointSingleCommand(name=\"gripper\")\n    num_steps = int(move_time / constants.DT)\n    curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]\n    traj_list = [\n        np.linspace(curr_pose, target_pose, num_steps)\n        for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)\n    ]\n\n    with open(f\"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl\", \"a\") as f:\n        for t in range(num_steps):\n            d = {}\n            for bot_id, bot in enumerate(bot_list):\n                gripper_command.cmd = traj_list[bot_id][t]\n                bot.gripper.core.pub_single.publish(gripper_command)\n       
         d[bot_id] = {\"obs\": get_arm_gripper_positions(bot), \"act\": traj_list[bot_id][t]}\n            f.write(json.dumps(d) + \"\\n\")\n            time.sleep(constants.DT)\n\n\ndef setup_puppet_bot(bot):\n    bot.dxl.robot_reboot_motors(\"single\", \"gripper\", True)\n    bot.dxl.robot_set_operating_modes(\"group\", \"arm\", \"position\")\n    bot.dxl.robot_set_operating_modes(\"single\", \"gripper\", \"current_based_position\")\n    torque_on(bot)\n\n\ndef setup_master_bot(bot):\n    bot.dxl.robot_set_operating_modes(\"group\", \"arm\", \"pwm\")\n    bot.dxl.robot_set_operating_modes(\"single\", \"gripper\", \"current_based_position\")\n    torque_off(bot)\n\n\ndef set_standard_pid_gains(bot):\n    bot.dxl.robot_set_motor_registers(\"group\", \"arm\", \"Position_P_Gain\", 800)\n    bot.dxl.robot_set_motor_registers(\"group\", \"arm\", \"Position_I_Gain\", 0)\n\n\ndef set_low_pid_gains(bot):\n    bot.dxl.robot_set_motor_registers(\"group\", \"arm\", \"Position_P_Gain\", 100)\n    bot.dxl.robot_set_motor_registers(\"group\", \"arm\", \"Position_I_Gain\", 0)\n\n\ndef torque_off(bot):\n    bot.dxl.robot_torque_enable(\"group\", \"arm\", False)\n    bot.dxl.robot_torque_enable(\"single\", \"gripper\", False)\n\n\ndef torque_on(bot):\n    bot.dxl.robot_torque_enable(\"group\", \"arm\", True)\n    bot.dxl.robot_torque_enable(\"single\", \"gripper\", True)\n\n\n# for DAgger\ndef sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):\n    print(\"\\nSyncing!\")\n\n    # activate master arms\n    torque_on(master_bot_left)\n    torque_on(master_bot_right)\n\n    # get puppet arm positions\n    puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)\n    puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)\n\n    # get puppet gripper positions\n    puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)\n    puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)\n\n    # move master arms to puppet 
positions\n    move_arms(\n        [master_bot_left, master_bot_right],\n        [puppet_left_qpos, puppet_right_qpos],\n        move_time=1,\n    )\n\n    # move master grippers to puppet positions\n    move_grippers(\n        [master_bot_left, master_bot_right],\n        [puppet_left_gripper, puppet_right_gripper],\n        move_time=1,\n    )\n"
  },
  {
    "path": "examples/aloha_real/video_display.py",
    "content": "import matplotlib.pyplot as plt\nimport numpy as np\nfrom openpi_client.runtime import subscriber as _subscriber\nfrom typing_extensions import override\n\n\nclass VideoDisplay(_subscriber.Subscriber):\n    \"\"\"Displays video frames.\"\"\"\n\n    def __init__(self) -> None:\n        self._ax: plt.Axes | None = None\n        self._plt_img: plt.Image | None = None\n\n    @override\n    def on_episode_start(self) -> None:\n        plt.ion()\n        self._ax = plt.subplot()\n        self._plt_img = None\n\n    @override\n    def on_step(self, observation: dict, action: dict) -> None:\n        assert self._ax is not None\n\n        im = observation[\"image\"][0]  # [C, H, W]\n        im = np.transpose(im, (1, 2, 0))  # [H, W, C]\n\n        if self._plt_img is None:\n            self._plt_img = self._ax.imshow(im)\n        else:\n            self._plt_img.set_data(im)\n        plt.pause(0.001)\n\n    @override\n    def on_episode_end(self) -> None:\n        plt.ioff()\n        plt.close()\n"
  },
  {
    "path": "examples/aloha_sim/Dockerfile",
    "content": "# Dockerfile for the Aloha simulation environment.\n\n# Build the container:\n# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile\n\n# Run the container:\n# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash\n\nFROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78\nCOPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/\n\nRUN apt-get update && \\\n    apt-get install -y \\\n    libosmesa6-dev \\\n    libgl1-mesa-glx \\\n    libglew-dev \\\n    libglfw3-dev \\\n    libgles2-mesa-dev\nENV MUJOCO_GL=egl\n\nWORKDIR /app\n\n# Copy from the cache instead of linking since it's a mounted volume\nENV UV_LINK_MODE=copy\n\n# Write the virtual environment outside of the project directory so it doesn't\n# leak out of the container when we mount the application code.\nENV UV_PROJECT_ENVIRONMENT=/.venv\n\n# Copy the requirements files so we can install dependencies.\n# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.\n# This strategy is best for development-style usage.\nCOPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt\nCOPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml\n\n# Install python dependencies.\nRUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT\nRUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml\nENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src\n\nCMD [\"/bin/bash\", \"-c\", \"source /.venv/bin/activate && python examples/aloha_sim/main.py\"]"
  },
  {
    "path": "examples/aloha_sim/README.md",
    "content": "# Run Aloha Sim\n\n## With Docker\n\n```bash\nexport SERVER_ARGS=\"--env ALOHA_SIM\"\ndocker compose -f examples/aloha_sim/compose.yml up --build\n```\n\n## Without Docker\n\nTerminal window 1:\n\n```bash\n# Create virtual environment\nuv venv --python 3.10 examples/aloha_sim/.venv\nsource examples/aloha_sim/.venv/bin/activate\nuv pip sync examples/aloha_sim/requirements.txt\nuv pip install -e packages/openpi-client\n\n# Run the simulation\nMUJOCO_GL=egl python examples/aloha_sim/main.py\n```\n\nNote: If you are seeing EGL errors, you may need to install the following dependencies:\n\n```bash\nsudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev\n```\n\nTerminal window 2:\n\n```bash\n# Run the server\nuv run scripts/serve_policy.py --env ALOHA_SIM\n```\n"
  },
  {
    "path": "examples/aloha_sim/compose.yml",
    "content": "# Run with:\n# docker compose -f examples/aloha_sim/compose.yml up --build\nservices:\n  runtime:\n    image: aloha_sim\n    depends_on:\n      - openpi_server\n    build:\n      context: ../..\n      dockerfile: examples/aloha_sim/Dockerfile\n    init: true\n    tty: true\n    network_mode: host\n    privileged: true\n    volumes:\n      - $PWD:/app\n      - ../../data:/data\n\n  openpi_server:\n    image: openpi_server\n    build:\n      context: ../..\n      dockerfile: scripts/docker/serve_policy.Dockerfile\n    init: true\n    tty: true\n    network_mode: host\n    volumes:\n      - $PWD:/app\n      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets\n    environment:\n      - SERVER_ARGS\n      - OPENPI_DATA_HOME=/openpi_assets\n      - IS_DOCKER=true\n\n    # Comment out this block if not running on a machine with GPUs.\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: 1\n              capabilities: [gpu]\n"
  },
  {
    "path": "examples/aloha_sim/env.py",
    "content": "import gym_aloha  # noqa: F401\nimport gymnasium\nimport numpy as np\nfrom openpi_client import image_tools\nfrom openpi_client.runtime import environment as _environment\nfrom typing_extensions import override\n\n\nclass AlohaSimEnvironment(_environment.Environment):\n    \"\"\"An environment for an Aloha robot in simulation.\"\"\"\n\n    def __init__(self, task: str, obs_type: str = \"pixels_agent_pos\", seed: int = 0) -> None:\n        np.random.seed(seed)\n        self._rng = np.random.default_rng(seed)\n\n        self._gym = gymnasium.make(task, obs_type=obs_type)\n\n        self._last_obs = None\n        self._done = True\n        self._episode_reward = 0.0\n\n    @override\n    def reset(self) -> None:\n        gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))\n        self._last_obs = self._convert_observation(gym_obs)  # type: ignore\n        self._done = False\n        self._episode_reward = 0.0\n\n    @override\n    def is_episode_complete(self) -> bool:\n        return self._done\n\n    @override\n    def get_observation(self) -> dict:\n        if self._last_obs is None:\n            raise RuntimeError(\"Observation is not set. 
Call reset() first.\")\n\n        return self._last_obs  # type: ignore\n\n    @override\n    def apply_action(self, action: dict) -> None:\n        gym_obs, reward, terminated, truncated, info = self._gym.step(action[\"actions\"])\n        self._last_obs = self._convert_observation(gym_obs)  # type: ignore\n        self._done = terminated or truncated\n        self._episode_reward = max(self._episode_reward, reward)\n\n    def _convert_observation(self, gym_obs: dict) -> dict:\n        img = gym_obs[\"pixels\"][\"top\"]\n        img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))\n        # Convert axis order from [H, W, C] --> [C, H, W]\n        img = np.transpose(img, (2, 0, 1))\n\n        return {\n            \"state\": gym_obs[\"agent_pos\"],\n            \"images\": {\"cam_high\": img},\n        }\n"
  },
  {
    "path": "examples/aloha_sim/main.py",
    "content": "import dataclasses\nimport logging\nimport pathlib\n\nimport env as _env\nfrom openpi_client import action_chunk_broker\nfrom openpi_client import websocket_client_policy as _websocket_client_policy\nfrom openpi_client.runtime import runtime as _runtime\nfrom openpi_client.runtime.agents import policy_agent as _policy_agent\nimport saver as _saver\nimport tyro\n\n\n@dataclasses.dataclass\nclass Args:\n    out_dir: pathlib.Path = pathlib.Path(\"data/aloha_sim/videos\")\n\n    task: str = \"gym_aloha/AlohaTransferCube-v0\"\n    seed: int = 0\n\n    action_horizon: int = 10\n\n    host: str = \"0.0.0.0\"\n    port: int = 8000\n\n    display: bool = False\n\n\ndef main(args: Args) -> None:\n    runtime = _runtime.Runtime(\n        environment=_env.AlohaSimEnvironment(\n            task=args.task,\n            seed=args.seed,\n        ),\n        agent=_policy_agent.PolicyAgent(\n            policy=action_chunk_broker.ActionChunkBroker(\n                policy=_websocket_client_policy.WebsocketClientPolicy(\n                    host=args.host,\n                    port=args.port,\n                ),\n                action_horizon=args.action_horizon,\n            )\n        ),\n        subscribers=[\n            _saver.VideoSaver(args.out_dir),\n        ],\n        max_hz=50,\n    )\n\n    runtime.run()\n\n\nif __name__ == \"__main__\":\n    logging.basicConfig(level=logging.INFO, force=True)\n    tyro.cli(main)\n"
  },
  {
    "path": "examples/aloha_sim/requirements.in",
    "content": "gym-aloha\nimageio\nmatplotlib\nmsgpack\nnumpy>=1.22.4,<2.0.0\ntyping-extensions\ntyro\nwebsockets"
  },
  {
    "path": "examples/aloha_sim/requirements.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10\nabsl-py==2.1.0\n    # via\n    #   dm-control\n    #   dm-env\n    #   labmaze\n    #   mujoco\ncertifi==2024.8.30\n    # via requests\ncharset-normalizer==3.4.0\n    # via requests\ncloudpickle==3.1.0\n    # via gymnasium\ncontourpy==1.3.1\n    # via matplotlib\ncycler==0.12.1\n    # via matplotlib\ndm-control==1.0.14\n    # via gym-aloha\ndm-env==1.6\n    # via dm-control\ndm-tree==0.1.8\n    # via\n    #   dm-control\n    #   dm-env\ndocstring-parser==0.16\n    # via tyro\nfarama-notifications==0.0.4\n    # via gymnasium\nfonttools==4.55.2\n    # via matplotlib\nglfw==2.8.0\n    # via\n    #   dm-control\n    #   mujoco\ngym-aloha==0.1.1\n    # via -r examples/aloha_sim/requirements.in\ngymnasium==1.0.0\n    # via gym-aloha\nidna==3.10\n    # via requests\nimageio==2.36.1\n    # via\n    #   -r examples/aloha_sim/requirements.in\n    #   gym-aloha\nimageio-ffmpeg==0.5.1\n    # via imageio\nkiwisolver==1.4.7\n    # via matplotlib\nlabmaze==1.0.6\n    # via dm-control\nlxml==5.3.0\n    # via dm-control\nmarkdown-it-py==3.0.0\n    # via rich\nmatplotlib==3.9.3\n    # via -r examples/aloha_sim/requirements.in\nmdurl==0.1.2\n    # via markdown-it-py\nmsgpack==1.1.0\n    # via -r examples/aloha_sim/requirements.in\nmujoco==2.3.7\n    # via\n    #   dm-control\n    #   gym-aloha\nnumpy==1.26.4\n    # via\n    #   -r examples/aloha_sim/requirements.in\n    #   contourpy\n    #   dm-control\n    #   dm-env\n    #   gymnasium\n    #   imageio\n    #   labmaze\n    #   matplotlib\n    #   mujoco\n    #   scipy\npackaging==24.2\n    # via matplotlib\npillow==11.0.0\n    # via\n    #   imageio\n    #   matplotlib\nprotobuf==5.29.1\n    # via dm-control\npsutil==6.1.0\n    # via imageio\npygments==2.18.0\n    # via rich\npyopengl==3.1.7\n    # via\n    #   dm-control\n    #   
mujoco\npyparsing==3.2.0\n    # via\n    #   dm-control\n    #   matplotlib\npython-dateutil==2.9.0.post0\n    # via matplotlib\nrequests==2.32.3\n    # via dm-control\nrich==13.9.4\n    # via tyro\nscipy==1.14.1\n    # via dm-control\nsetuptools==75.6.0\n    # via\n    #   dm-control\n    #   imageio-ffmpeg\n    #   labmaze\nshtab==1.7.1\n    # via tyro\nsix==1.17.0\n    # via python-dateutil\ntqdm==4.67.1\n    # via dm-control\ntypeguard==4.4.1\n    # via tyro\ntyping-extensions==4.12.2\n    # via\n    #   -r examples/aloha_sim/requirements.in\n    #   gymnasium\n    #   rich\n    #   typeguard\n    #   tyro\ntyro==0.9.2\n    # via -r examples/aloha_sim/requirements.in\nurllib3==2.2.3\n    # via requests\nwebsockets==14.1\n    # via -r examples/aloha_sim/requirements.in\n"
  },
  {
    "path": "examples/aloha_sim/saver.py",
    "content": "import logging\nimport pathlib\n\nimport imageio\nimport numpy as np\nfrom openpi_client.runtime import subscriber as _subscriber\nfrom typing_extensions import override\n\n\nclass VideoSaver(_subscriber.Subscriber):\n    \"\"\"Saves episode data.\"\"\"\n\n    def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:\n        out_dir.mkdir(parents=True, exist_ok=True)\n        self._out_dir = out_dir\n        self._images: list[np.ndarray] = []\n        self._subsample = subsample\n\n    @override\n    def on_episode_start(self) -> None:\n        self._images = []\n\n    @override\n    def on_step(self, observation: dict, action: dict) -> None:\n        im = observation[\"images\"][\"cam_high\"]  # [C, H, W]\n        im = np.transpose(im, (1, 2, 0))  # [H, W, C]\n        self._images.append(im)\n\n    @override\n    def on_episode_end(self) -> None:\n        existing = list(self._out_dir.glob(\"out_[0-9]*.mp4\"))\n        next_idx = max([int(p.stem.split(\"_\")[1]) for p in existing], default=-1) + 1\n        out_path = self._out_dir / f\"out_{next_idx}.mp4\"\n\n        logging.info(f\"Saving video to {out_path}\")\n        imageio.mimwrite(\n            out_path,\n            [np.asarray(x) for x in self._images[:: self._subsample]],\n            fps=50 // max(1, self._subsample),\n        )\n"
  },
  {
    "path": "examples/convert_jax_model_to_pytorch.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nLoad a JAX model and print all parameter keys, with optional conversion to PyTorch.\n\nThis script loads a JAX model checkpoint using orbax and can either:\n1. Print out all the parameter keys in a hierarchical structure for inspection\n2. Convert the JAX model to PyTorch format using our PI0Pytorch model\n\nUsage:\n    # Just inspect keys:\n    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only\n    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only\n\n    # Convert to PyTorch:\n    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output\n    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output\n\nExample:\n    # pi0_droid\n    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch\n\n    # pi0_aloha_sim\n    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch\n\n    # pi05_droid\n    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch\n\"\"\"\n\nimport json\nimport os\nimport pathlib\nimport shutil\nfrom typing import Literal\n\nfrom flax.nnx import traversals\nimport numpy as np\nimport orbax.checkpoint as ocp\nimport safetensors\nimport torch\nimport tyro\n\nimport openpi.models.gemma\nimport openpi.models.model\nimport openpi.models.pi0_config\nimport openpi.models_pytorch.pi0_pytorch\nfrom openpi.training import 
utils\nimport openpi.training.config as _config\n\n\ndef slice_paligemma_state_dict(state_dict, config):\n    \"\"\"Convert PaliGemma JAX parameters to PyTorch format.\"\"\"\n    suffix = \"/value\" if \"img/embedding/kernel/value\" in state_dict else \"\"\n\n    # patch embeddings\n    jax_key = f\"img/embedding/kernel{suffix}\"\n    pytorch_key = \"paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight\"\n    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)\n\n    jax_key = f\"img/embedding/bias{suffix}\"\n    pytorch_key = \"paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias\"\n    state_dict[pytorch_key] = state_dict.pop(jax_key)\n\n    # positional embeddings\n    jax_key = f\"img/pos_embedding{suffix}\"\n    pytorch_key = \"paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight\"\n    state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)\n\n    # extract vision layers to be sliced at index 0. 
There are 27 layers in the base model.\n    encoderblock_layernorm0_scale = state_dict.pop(f\"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}\")\n    encoderblock_layernorm0_bias = state_dict.pop(f\"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}\")\n    encoderblock_layernorm1_scale = state_dict.pop(f\"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}\")\n    encoderblock_layernorm1_bias = state_dict.pop(f\"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}\")\n\n    encoderblock_mlp_dense0_kernel = state_dict.pop(f\"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}\")\n    encoderblock_mlp_dense0_bias = state_dict.pop(f\"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}\")\n    encoderblock_mlp_dense1_kernel = state_dict.pop(f\"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}\")\n    encoderblock_mlp_dense1_bias = state_dict.pop(f\"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}\")\n\n    encoderblock_attention_0_key_kernel = state_dict.pop(\n        f\"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}\"\n    )\n    encoderblock_attention_0_key_bias = state_dict.pop(\n        f\"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}\"\n    )\n    encoderblock_attention_0_value_kernel = state_dict.pop(\n        f\"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}\"\n    )\n    encoderblock_attention_0_value_bias = state_dict.pop(\n        f\"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}\"\n    )\n    encoderblock_attention_0_query_kernel = state_dict.pop(\n        f\"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}\"\n    )\n    encoderblock_attention_0_query_bias = state_dict.pop(\n        f\"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}\"\n    )\n    encoderblock_attention_0_out_kernel = state_dict.pop(\n        
f\"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}\"\n    )\n    encoderblock_attention_0_out_bias = state_dict.pop(\n        f\"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}\"\n    )\n\n    for i in range(config.vision_config.num_hidden_layers):\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight\"\n        ] = encoderblock_layernorm0_scale[i].transpose()\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias\"\n        ] = encoderblock_layernorm0_bias[i]\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight\"\n        ] = encoderblock_layernorm1_scale[i].transpose()\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias\"\n        ] = encoderblock_layernorm1_bias[i]\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight\"\n        ] = encoderblock_mlp_dense0_kernel[i].transpose()\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias\"\n        ] = encoderblock_mlp_dense0_bias[i]\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight\"\n        ] = encoderblock_mlp_dense1_kernel[i].transpose()\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias\"\n        ] = encoderblock_mlp_dense1_bias[i]\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight\"\n        ] = 
encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias\"\n        ] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight\"\n        ] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias\"\n        ] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight\"\n        ] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias\"\n        ] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight\"\n        ] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()\n        state_dict[\n            f\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias\"\n        ] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)\n\n    jax_key = f\"img/Transformer/encoder_norm/scale{suffix}\"\n    pytorch_key = 
\"paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight\"\n    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()\n\n    jax_key = f\"img/Transformer/encoder_norm/bias{suffix}\"\n    pytorch_key = \"paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias\"\n    state_dict[pytorch_key] = state_dict.pop(jax_key)\n\n    # multimodal projector\n    jax_key = f\"img/head/kernel{suffix}\"\n    pytorch_key = \"paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight\"\n    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()\n\n    jax_key = f\"img/head/bias{suffix}\"\n    pytorch_key = \"paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias\"\n    state_dict[pytorch_key] = state_dict.pop(jax_key)\n\n    # text decoder (gemma)\n    jax_key = f\"llm/embedder/input_embedding{suffix}\"\n    pytorch_key = \"paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight\"\n    state_dict[pytorch_key] = state_dict.pop(jax_key)\n\n    # pop the einsum attention + mlp representations\n    llm_attention_attn_vec_einsum = state_dict.pop(f\"llm/layers/attn/attn_vec_einsum/w{suffix}\")\n    llm_attention_kv_einsum = state_dict.pop(f\"llm/layers/attn/kv_einsum/w{suffix}\")\n    llm_attention_q_einsum = state_dict.pop(f\"llm/layers/attn/q_einsum/w{suffix}\")\n\n    llm_mlp_gating_einsum = state_dict.pop(f\"llm/layers/mlp/gating_einsum{suffix}\")\n    llm_mlp_linear = state_dict.pop(f\"llm/layers/mlp/linear{suffix}\")\n\n    llm_input_layernorm = state_dict.pop(f\"llm/layers/pre_attention_norm/scale{suffix}\")\n    llm_post_attention_layernorm = state_dict.pop(f\"llm/layers/pre_ffw_norm/scale{suffix}\")\n\n    for i in range(config.text_config.num_hidden_layers):\n        q_proj_weight_reshaped = (\n            llm_attention_q_einsum[i]\n            .transpose(0, 2, 1)\n            .reshape(\n                config.text_config.num_attention_heads * 
config.text_config.head_dim, config.text_config.hidden_size\n            )\n        )\n        state_dict[f\"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight\"] = (\n            q_proj_weight_reshaped\n        )\n\n        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()\n        state_dict[f\"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight\"] = (\n            k_proj_weight_reshaped\n        )\n        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()\n        state_dict[f\"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight\"] = (\n            v_proj_weight_reshaped\n        )\n\n        o_proj_weight_reshaped = (\n            llm_attention_attn_vec_einsum[i]\n            .transpose(2, 0, 1)\n            .reshape(\n                config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size\n            )\n        )\n        state_dict[f\"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight\"] = (\n            o_proj_weight_reshaped\n        )\n\n        gate_proj_weight = llm_mlp_gating_einsum[i, 0]\n        state_dict[f\"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight\"] = (\n            gate_proj_weight.transpose()\n        )\n        up_proj_weight = llm_mlp_gating_einsum[i, 1]\n        state_dict[f\"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight\"] = (\n            up_proj_weight.transpose()\n        )\n        state_dict[f\"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight\"] = (\n            llm_mlp_linear[i].transpose()\n        )\n        state_dict[f\"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight\"] = (\n            llm_input_layernorm[i]\n        )\n        state_dict[\n           
 f\"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight\"\n        ] = llm_post_attention_layernorm[i]\n\n    jax_key = f\"llm/final_norm/scale{suffix}\"\n    pytorch_key = \"paligemma_with_expert.paligemma.model.language_model.norm.weight\"\n    state_dict[pytorch_key] = state_dict.pop(jax_key)\n\n    expert_dict = {}\n    final_state_dict = {}\n\n    # Expert-related keys to extract (including pi05 Dense layer parameters)\n    expert_keys = [\n        f\"llm/final_norm_1/scale{suffix}\",\n        f\"llm/final_norm_1/Dense_0/bias{suffix}\",\n        f\"llm/final_norm_1/Dense_0/kernel{suffix}\",\n        f\"llm/layers/attn/attn_vec_einsum_1/w{suffix}\",\n        f\"llm/layers/attn/kv_einsum_1/w{suffix}\",\n        f\"llm/layers/attn/q_einsum_1/w{suffix}\",\n        f\"llm/layers/mlp_1/gating_einsum{suffix}\",\n        f\"llm/layers/mlp_1/linear{suffix}\",\n        f\"llm/layers/pre_attention_norm_1/scale{suffix}\",\n        f\"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}\",\n        f\"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}\",\n        f\"llm/layers/pre_ffw_norm_1/scale{suffix}\",\n        f\"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}\",\n        f\"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}\",\n    ]\n\n    for key, value in state_dict.items():\n        if key not in expert_keys:\n            final_state_dict[key] = torch.from_numpy(value)\n        else:\n            expert_dict[key] = value\n\n    return final_state_dict, expert_dict\n\n\ndef slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05):\n    \"\"\"Convert Gemma JAX parameters to PyTorch format.\"\"\"\n    # Add missing attributes to config if they don't exist\n    if not hasattr(config, \"vocab_size\"):\n        config.vocab_size = 257152  # PALIGEMMA_VOCAB_SIZE\n    if not hasattr(config, \"hidden_size\"):\n        config.hidden_size = config.width\n    if not hasattr(config, 
\"num_hidden_layers\"):\n        config.num_hidden_layers = config.depth\n    if not hasattr(config, \"num_attention_heads\"):\n        config.num_attention_heads = config.num_heads\n\n    suffix = \"/value\" if f\"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value\" in state_dict else \"\"\n\n    llm_attention_attn_vec_einsum = state_dict.pop(f\"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}\")\n    llm_attention_kv_einsum = state_dict.pop(f\"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}\")\n    llm_attention_q_einsum = state_dict.pop(f\"llm/layers/attn/q_einsum_{num_expert}/w{suffix}\")\n\n    llm_mlp_gating_einsum = state_dict.pop(f\"llm/layers/mlp_{num_expert}/gating_einsum{suffix}\")\n    llm_mlp_linear = state_dict.pop(f\"llm/layers/mlp_{num_expert}/linear{suffix}\")\n\n    # Check if we have Dense layers (for pi05/adaptive normalization) or scale layers (for regular pi0)\n    if \"pi05\" in checkpoint_dir:\n        # Pi05 with adaptive normalization\n        llm_input_layernorm_bias = state_dict.pop(f\"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}\")\n        llm_post_attention_layernorm_bias = state_dict.pop(f\"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}\")\n        llm_input_layernorm_kernel = state_dict.pop(\n            f\"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}\"\n        )\n        llm_post_attention_layernorm_kernel = state_dict.pop(\n            f\"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}\"\n        )\n    else:\n        # Regular pi0 with standard RMSNorm\n        llm_input_layernorm = state_dict.pop(f\"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}\")\n        llm_post_attention_layernorm = state_dict.pop(f\"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}\")\n\n    for i in range(config.num_hidden_layers):\n        q_proj_weight_reshaped = (\n            llm_attention_q_einsum[i]\n            .transpose(0, 2, 1)\n            
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)\n        )\n        state_dict[f\"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight\"] = (\n            q_proj_weight_reshaped\n        )\n\n        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()\n        state_dict[f\"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight\"] = (\n            k_proj_weight_reshaped\n        )\n        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()\n        state_dict[f\"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight\"] = (\n            v_proj_weight_reshaped\n        )\n\n        o_proj_weight_reshaped = (\n            llm_attention_attn_vec_einsum[i]\n            .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)\n            .transpose(1, 0)\n        )\n        state_dict[f\"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight\"] = (\n            o_proj_weight_reshaped\n        )\n\n        gate_proj_weight = llm_mlp_gating_einsum[i, 0]\n        state_dict[f\"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight\"] = (\n            gate_proj_weight.transpose()\n        )\n        up_proj_weight = llm_mlp_gating_einsum[i, 1]\n        state_dict[f\"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight\"] = (\n            up_proj_weight.transpose()\n        )\n        state_dict[f\"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight\"] = llm_mlp_linear[\n            i\n        ].transpose()\n\n        if \"pi05\" in checkpoint_dir:\n            # Pi05 with adaptive normalization - use Dense layer parameters directly\n            state_dict[f\"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias\"] = (\n                llm_input_layernorm_bias[i]\n            )\n            
state_dict[f\"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias\"] = (\n                llm_post_attention_layernorm_bias[i]\n            )\n            state_dict[f\"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight\"] = (\n                llm_input_layernorm_kernel[i].transpose()\n            )\n            state_dict[f\"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight\"] = (\n                llm_post_attention_layernorm_kernel[i].transpose()\n            )\n        else:\n            # Regular pi0 with standard RMSNorm\n            state_dict[f\"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight\"] = (\n                llm_input_layernorm[i]\n            )\n            state_dict[f\"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight\"] = (\n                llm_post_attention_layernorm[i]\n            )\n\n    # Handle final norm layer\n    if \"pi05\" in checkpoint_dir:\n        # Pi05 with adaptive normalization - use Dense layer parameters directly\n        final_norm_bias = state_dict.pop(f\"llm/final_norm_{num_expert}/Dense_0/bias{suffix}\")\n        final_norm_kernel = state_dict.pop(f\"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}\")\n        state_dict[\"paligemma_with_expert.gemma_expert.model.norm.dense.bias\"] = final_norm_bias\n        state_dict[\"paligemma_with_expert.gemma_expert.model.norm.dense.weight\"] = final_norm_kernel.transpose()\n    else:\n        # Regular pi0 with standard RMSNorm\n        state_dict[\"paligemma_with_expert.gemma_expert.model.norm.weight\"] = state_dict.pop(\n            f\"llm/final_norm_{num_expert}/scale{suffix}\"\n        )\n\n        # state_dict[\"paligemma_with_expert.gemma_expert.lm_head.weight\"] = embedding_vector # weights are tied.\n\n    final_state_dict = {}\n    for key, value in state_dict.items():\n        if not isinstance(value, 
torch.Tensor):\n            final_state_dict[key] = torch.from_numpy(value)\n        else:\n            final_state_dict[key] = value\n\n    return final_state_dict\n\n\ndef slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None):\n    \"\"\"Load and process params by restoring via JAX model loader first.\n    This respects dtype conversions that occur during model restore.\n    \"\"\"\n    # Use repository restore utility to load a pure dict of params (value suffix removed)\n    params = openpi.models.model.restore_params(\n        f\"{checkpoint_dir}/params/\", restore_type=np.ndarray, dtype=restore_precision\n    )\n\n    return {\"paligemma_params\": traversals.flatten_mapping(params[\"PaliGemma\"], sep=\"/\"), \"projection_params\": params}\n\n\ndef load_jax_model_and_print_keys(checkpoint_dir: str):\n    \"\"\"\n    Load JAX model from checkpoint and print all parameter keys.\n\n    Args:\n        checkpoint_dir: Path to the checkpoint directory\n    \"\"\"\n    checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith(\"gs://\") else checkpoint_dir\n    # Initialize checkpointer\n    checkpointer = ocp.PyTreeCheckpointer()\n    metadata = checkpointer.metadata(f\"{checkpoint_dir}/params\")\n    print(utils.array_tree_to_info(metadata))\n\n\ndef convert_pi0_checkpoint(\n    checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config\n):\n    \"\"\"\n    Convert PI0 JAX checkpoint to PyTorch format.\n\n    Args:\n        checkpoint_dir: Path to the JAX checkpoint\n        precision: Model precision (float32, bfloat16, float16)\n        output_path: Path to save the converted PyTorch model\n        model_config: Model config\n    \"\"\"\n    print(f\"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}\")\n    print(f\"Model config: {model_config}\")\n\n    # Break down orbax ckpts by restoring via JAX to respect dtype\n    initial_params = 
slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision=\"float32\")\n\n    # Process projection params\n    if model_config.pi05:\n        keys = [\n            \"action_in_proj\",\n            \"action_out_proj\",\n            \"time_mlp_in\",\n            \"time_mlp_out\",\n        ]\n    else:\n        keys = [\n            \"state_proj\",\n            \"action_in_proj\",\n            \"action_out_proj\",\n            \"action_time_mlp_in\",\n            \"action_time_mlp_out\",\n        ]\n\n    projection_params = {}\n    for key in keys:\n        kernel_params = initial_params[\"projection_params\"][key][\"kernel\"]\n        bias_params = initial_params[\"projection_params\"][key][\"bias\"]\n        if isinstance(kernel_params, dict):\n            weight = kernel_params[\"value\"]\n            bias = bias_params[\"value\"]\n        else:\n            weight = kernel_params\n            bias = bias_params\n\n        pytorch_weight_key = f\"{key}.weight\"\n        pytorch_bias_key = f\"{key}.bias\"\n\n        projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T\n        projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))\n\n    # Create configs based on checkpoint path\n    # All models use the same PaliGemma config structure\n    class PaliGemmaConfig:\n        def __init__(self):\n            self.vision_config = type(\n                \"obj\",\n                (object,),\n                {\n                    \"hidden_size\": 1152,\n                    \"num_hidden_layers\": 27,\n                    \"num_attention_heads\": 16,\n                    \"intermediate_size\": 4304,\n                    \"patch_size\": 14,\n                    \"projection_dim\": 2048,\n                },\n            )()\n            self.text_config = type(\n                \"obj\",\n                (object,),\n                {\n                    \"hidden_size\": 2048,\n                    
\"num_hidden_layers\": 18,\n                    \"num_attention_heads\": 8,\n                    \"head_dim\": 256,\n                    \"intermediate_size\": 16384,\n                },\n            )()\n\n    paligemma_config = PaliGemmaConfig()\n    action_expert_config = openpi.models.gemma.get_config(\"gemma_300m\")\n\n    # Process PaliGemma weights\n    paligemma_params, expert_params = slice_paligemma_state_dict(initial_params[\"paligemma_params\"], paligemma_config)\n\n    # Process Gemma weights from expert_params\n    gemma_params = slice_gemma_state_dict(\n        expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05\n    )\n\n    # Instantiate model\n    pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)\n\n    # Combine all parameters (no prefix needed for our model structure)\n    all_params = {**paligemma_params, **gemma_params, **projection_params}\n\n    # Load state dict\n    pi0_model.load_state_dict(all_params, strict=False)\n\n    if precision == \"float32\":\n        pi0_model = pi0_model.to(torch.float32)\n    elif precision == \"bfloat16\":\n        pi0_model = pi0_model.to(torch.bfloat16)\n    else:\n        raise ValueError(f\"Invalid precision: {precision}\")\n\n    # Save the converted model using safetensors\n    os.makedirs(output_path, exist_ok=True)\n\n    # Save model weights as SafeTensors using save_model to handle tied weights\n    safetensors.torch.save_model(pi0_model, os.path.join(output_path, \"model.safetensors\"))\n\n    # Copy assets folder if it exists\n    assets_source = pathlib.Path(checkpoint_dir).parent / \"assets\"\n    if assets_source.exists():\n        assets_dest = pathlib.Path(output_path) / \"assets\"\n        if assets_dest.exists():\n            shutil.rmtree(assets_dest)\n        shutil.copytree(assets_source, assets_dest)\n\n    # Save config as JSON for reference\n    config_dict = {\n        \"action_dim\": model_config.action_dim,\n 
       \"action_horizon\": model_config.action_horizon,\n        \"paligemma_variant\": model_config.paligemma_variant,\n        \"action_expert_variant\": model_config.action_expert_variant,\n        \"precision\": precision,\n    }\n    with open(os.path.join(output_path, \"config.json\"), \"w\") as f:\n        json.dump(config_dict, f, indent=2)\n\n    print(\"Model conversion completed successfully!\")\n    print(f\"Model saved to {output_path}\")\n\n\ndef main(\n    checkpoint_dir: str,\n    config_name: str,\n    output_path: str | None = None,\n    precision: Literal[\"float32\", \"bfloat16\", \"float16\"] = \"bfloat16\",\n    *,\n    inspect_only: bool = False,\n):\n    \"\"\"Load JAX model and optionally convert to PyTorch.\n\n    Args:\n        checkpoint_dir: Path to the JAX checkpoint directory\n        output_path: Path to save converted PyTorch model (required for conversion)\n        precision: Precision for model conversion\n        inspect_only: Only inspect parameter keys, don't convert\n    \"\"\"\n    model_config = _config.get_config(config_name).model\n    if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):\n        raise ValueError(f\"Config {config_name} is not a Pi0Config\")\n    if inspect_only:\n        load_jax_model_and_print_keys(checkpoint_dir)\n    else:\n        if not output_path:\n            print(\"Error: --output_path is required for conversion. Use --inspect_only to only view keys.\")\n            return\n        convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)\n\n\nif __name__ == \"__main__\":\n    tyro.cli(main)\n"
  },
  {
    "path": "examples/droid/README.md",
    "content": "# DROID Policies in openpi\n\nWe offer instructions for:\n- [Running inference for our best $\\pi_{0.5}$-DROID policy](./README.md#running-droid-inference)\n- [Running inference for other pre-trained DROID policies ($\\pi_0$, $\\pi_0$-FAST, ...)](./README.md#running-other-policies)\n- [Pre-training *generalist* policies on the *full* DROID dataset](./README_train.md#training-on-droid)\n- [Fine-tuning expert $\\pi_{0.5}$ on your custom DROID dataset](./README_train.md#fine-tuning-on-custom-droid-datasets)\n\n## Running DROID Inference\n\nThis example shows how to run the fine-tuned $\\pi_{0.5}$-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). Based on the [public RoboArena benchmark](https://robo-arena.github.io/leaderboard), this is currently our strongest generalist DROID policy.\n\n\n### Step 1: Start a policy server\n\nSince the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.\n\n1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).\n2. Start the OpenPI server via the following command:\n\n```bash\nuv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=gs://openpi-assets/checkpoints/pi05_droid\n```\n\nYou can also run the equivalent command below:\n\n```bash\nuv run scripts/serve_policy.py --env=DROID\n```\n\n### Step 2: Run the DROID robot\n\n1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.\n2. On the control laptop, activate your DROID conda environment.\n3. 
Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.\n4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.\n5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.\n6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).\n7. Run the `main.py` file. Make sure `--remote_host` and `--remote_port` point to the IP address and port of the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also make sure to specify which external camera to use for the policy (the policy takes only one external camera as input); choose from [\"left\", \"right\"].\n\n```bash\npython3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera=\"left\"\n```\n\nThe script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!\n\n## Troubleshooting\n\n| Issue | Solution |\n|-------|----------|\n| Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |\n| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. 
Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explorer` in the command line. |\n| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |\n| Policy does not perform the task well | In our experiments, the policy could perform simple tabletop manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera; make sure you are feeding the desired camera to the policy). Use `ZED_Explorer` to check that the camera view you are passing to the policy can see all relevant objects in the scene. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |\n\n\n## Running Other Policies\n\nWe provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies. 
Then follow the instructions above to run evaluation on the DROID robot.\n\n```bash\n# Trained from pi0-FAST, using FAST tokenizer\nuv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid\n\n# Trained from pi0, using flow matching\nuv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_droid\n\n# Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.\nuv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid\n\n# Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).\nuv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid\n\n# Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).\nuv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid\n\n# Trained from PaliGemma, using FSQ tokenizer.\nuv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid\n\n# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.\nuv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid\n```\n\nYou can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py).\n"
  },
  {
    "path": "examples/droid/README_train.md",
    "content": "# Training on DROID\n\nHere we describe how to fine-tune the pi0.5 model on the *full* DROID dataset. This is an approximate open-source reproduction of the pi05-DROID training pipeline\n(with small differences in data loading and in the action space used). For a tutorial on how to fine-tune your model with a smaller, custom dataset collected on the DROID platform, see below.\n\nIn contrast to the rest of openpi, which uses LeRobot for data loading, we need to use RLDS as the data format for full DROID training (since at the moment LeRobot isn't scalable enough\nfor larger datasets like DROID -- they are working on improving it though). Below, we provide instructions for updating your openpi environment for RLDS data loading and where to download the DROID dataset.\n\n## Install\n\nWe need a few additional dependencies for RLDS data loading. Run:\n```bash\nuv sync --group rlds\n```\n\n## Download DROID dataset\n\nYou can download the DROID dataset with the following command (after installing the `gsutil` Google Cloud CLI):\n```\ngsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 <your_download_path>/droid/1.0.1\n```\n\nNote that downloading version 1.0.1 is important (not v1.0.0): it contains the complete set of language annotations (~75k episodes) while v1.0.0 only has annotations for 30k episodes. 
If for some reason you would like to use another version, modify the line `version=\"1.0.1\"` in the `DroidRldsDataset` object [here](../../src/openpi/training/droid_rlds_dataset.py).\n\nYou will need 1.8TB of disk storage to download the DROID RLDS dataset.\n\n## Run\n\nFirst, change the `rlds_data_dir` path in your `TrainConfig` to the directory that you downloaded the `droid` dataset into (see [src/openpi/training/config.py](../../src/openpi/training/config.py)).\n\nThen, compute normalization statistics (this will take ~10 minutes):\n```bash\nuv run --group rlds scripts/compute_norm_stats.py --config-name pi05_full_droid_finetune --max-frames 10_000_000\n```\n\nRun training:\n```bash\nXLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py pi05_full_droid_finetune --exp-name=my_experiment --overwrite\n```\n\n**Note**: The original pi0.5-DROID model was trained with joint velocity actions.\nJoint velocity actions are not compatible with simulated evaluation environments (they are much harder to simulate).\nThus, we do not recommend training with joint velocity actions and use joint position actions here instead.\n\n\n## Compute Requirements\n\nOur DROID training config requires approximately 2 days on 8x H100 GPUs for convergence (100k iterations, bs256, approx. 1 epoch).\nIf you start from PaliGemma instead of pi0 initialization, plan for ~5 days on 8x H100s (240k iterations, i.e. 3 epochs).\n\nWe have experimented with LoRA for cheaper finetuning, but haven't found the policies to perform well so far.\n\n\n## Data Filtering\n\nLike any diverse real-robot dataset, the DROID dataset isn't perfectly \"clean\" and we have found data filtering to significantly improve policy performance. Concretely, the DROID dataset contains many *idle* timesteps in which the robot does not move (in part due to the VR teleoperation interface that was used during data collection; we will not go into too much detail here). 
Appropriate filtering of these idle transitions can improve policy performance.\n\nBy default, our openpi training recipe implements the same idle filter used to train all pi-DROID models. We implement it by pre-computing which dataset indices to sample during training. You can check [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) for how we compute these indices. Roughly speaking, we filter out any time steps for which the next chunk of actions would be largely idle. During training, our code automatically pulls our pre-computed list of indices from cloud storage and applies them. If you want to modify the idle filter or create your own sampling logic, you can modify our script to generate a new index list and provide it via the `filter_dict_path=\"<path_to_filter_dict>\"` argument in [src/openpi/training/config.py](src/openpi/training/config.py).\n\n**Note**: Our list of filtering indices is only valid for the `droid/1.0.1` dataset mentioned in the download section above, and will not provide valid filtering for any other version of the DROID dataset, so make sure you download the dataset above! If you have a custom DROID version, you can rerun the [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) script to generate a new list of sampling indices.\n\n## RoboArena\n\nConsider submitting your DROID policies to the [RoboArena benchmark](https://robo-arena.github.io/), which allows you to evaluate your policies on diverse tasks & scenes, **in the real world**! :)\n\nIf you have questions about RoboArena, please email [karl.pertsch@gmail.com](mailto:karl.pertsch@gmail.com).\n\n\n# Fine-Tuning on Custom DROID Datasets\n\nHere we describe how to fine-tune a model on a custom (smaller) dataset collected on the DROID platform. 
As with other datasets, we will first convert the custom DROID dataset to LeRobot and then fine-tune a model (pi05-droid) on it.\n\nNote: We use LeRobot here, since we assume that the custom DROID fine-tuning dataset is relatively small (<10s of hours). For larger datasets (like the full DROID dataset) we recommend using RLDS for its better efficiency (see the example above).\n\n\n## Step 1: Converting your custom DROID dataset to LeRobot\n\nWe will use a small subset of the real DROID dataset for this example. This is a subset of just 30 demonstrations -- we assume that you will use your own dataset instead, but here is the command to download our subset (1.6GB):\n```\ngsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04 <your_target_path>\n```\n\nWe will also download the language annotations for the DROID dataset so we can pair our demonstrations with language instructions. Again, for your own data you can manually enter your language instructions and don't need to download our annotations. To download the DROID language annotations (12MB) into the same directory as the demonstrations (the conversion script expects the annotation file directly inside the `--data_dir` passed to it), run:\n```\ngsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json <your_target_path>\n```\n\nFor your own dataset, make sure that each episode's directory contains a folder called `recordings/MP4` -- if not, you need to first run the MP4 video extraction (from SVO files) using the script [here](https://github.com/droid-dataset/droid/blob/main/scripts/convert/svo_to_mp4.py).\n\nNow, we will use the `convert_droid_data_to_lerobot.py` script to create a LeRobot version of this dataset (takes <5min for the 30 demonstrations):\n```\nuv run examples/droid/convert_droid_data_to_lerobot.py --data_dir <your_target_path>\n```\n\n## Step 2: Run fine-tuning with your custom dataset\n\nNow we can run fine-tuning with our converted custom dataset. We provide an example config for fine-tuning `pi05_droid` on the custom dataset we created. 
\nYou can easily modify the config in `config.py` (search for `pi05_droid_finetune`) to work with other base models or to use your own custom DROID dataset.\n\nTo launch training:\n```\nuv run scripts/train.py pi05_droid_finetune --exp-name=my_experiment --overwrite\n```\n\nOnce trained, you can follow the instructions in [`examples/droid/README.md`](examples/droid/README.md) to serve the policy and run it on the robot.\n\n"
  },
  {
    "path": "examples/droid/compute_droid_nonidle_ranges.py",
    "content": "\"\"\"\nIterates through the DROID dataset and creates a json mapping from episode unique IDs to ranges of time steps\nthat should be sampled during training (all others are filtered out).\n\nFiltering logic:\nWe look for ranges of consecutive steps that contain at most min_idle_len consecutive idle frames\n(default to 7 -- as most DROID action-chunking policies run the first 8 actions generated in each chunk, filtering\nthis way means the policy will not get stuck outputting stationary actions). Additionally, we also only keep non-idle\nranges of length at least min_non_idle_len (default to 16 frames = ~1 second), while also removing the last\nfilter_last_n_in_ranges frames from the end of each range (as those all correspond to action chunks with many idle actions).\n\nThis leaves us with trajectory segments consisting of contiguous, significant movement. Training on this filtered set\nyields policies that output fewer stationary actions (i.e., get \"stuck\" in states less).\n\"\"\"\n\nimport json\nimport os\nfrom pathlib import Path\n\nimport numpy as np\nimport tensorflow as tf\nimport tensorflow_datasets as tfds\nfrom tqdm import tqdm\n\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"  # Set to the GPU you want to use, or leave empty for CPU\n\nbuilder = tfds.builder_from_directory(\n    # path to the `droid` directory (not its parent)\n    builder_dir=\"<path_to_droid_dataset_tfds_files>\",\n)\nds = builder.as_dataset(split=\"train\", shuffle_files=False)\ntf.data.experimental.ignore_errors(ds)\n\nkeep_ranges_path = \"<path_to_where_to_save_the_json>\"\n\nmin_idle_len = 7  # If more than this number of consecutive idle frames, filter all of them out\nmin_non_idle_len = 16  # If fewer than this number of consecutive non-idle frames, filter all of them out\nfilter_last_n_in_ranges = 10  # When using a filter dict, remove this many frames from the end of each range\n\nkeep_ranges_map = {}\nif Path(keep_ranges_path).exists():\n    with 
Path(keep_ranges_path).open(\"r\") as f:\n        keep_ranges_map = json.load(f)\n    print(f\"Resuming from {len(keep_ranges_map)} episodes already processed\")\n\nfor ep_idx, ep in enumerate(tqdm(ds)):\n    recording_folderpath = ep[\"episode_metadata\"][\"recording_folderpath\"].numpy().decode()\n    file_path = ep[\"episode_metadata\"][\"file_path\"].numpy().decode()\n\n    key = f\"{recording_folderpath}--{file_path}\"\n    if key in keep_ranges_map:\n        continue\n\n    joint_velocities = [step[\"action_dict\"][\"joint_velocity\"].numpy() for step in ep[\"steps\"]]\n    joint_velocities = np.array(joint_velocities)\n\n    is_idle_array = np.hstack(\n        [np.array([False]), np.all(np.abs(joint_velocities[1:] - joint_velocities[:-1]) < 1e-3, axis=1)]\n    )\n\n    # Find what steps go from idle to non-idle and vice-versa\n    is_idle_padded = np.concatenate(\n        [[False], is_idle_array, [False]]\n    )  # Start and end with False, so idle at first step is a start of motion\n\n    is_idle_diff = np.diff(is_idle_padded.astype(int))\n    is_idle_true_starts = np.where(is_idle_diff == 1)[0]  # +1 transitions --> going from idle to non-idle\n    is_idle_true_ends = np.where(is_idle_diff == -1)[0]  # -1 transitions --> going from non-idle to idle\n\n    # Find which steps correspond to idle segments of length at least min_idle_len\n    true_segment_masks = (is_idle_true_ends - is_idle_true_starts) >= min_idle_len\n    is_idle_true_starts = is_idle_true_starts[true_segment_masks]\n    is_idle_true_ends = is_idle_true_ends[true_segment_masks]\n\n    keep_mask = np.ones(len(joint_velocities), dtype=bool)\n    for start, end in zip(is_idle_true_starts, is_idle_true_ends, strict=True):\n        keep_mask[start:end] = False\n\n    # Get all non-idle ranges of at least 16\n    # Same logic as above, but for keep_mask, allowing us to filter out contiguous ranges of length < min_non_idle_len\n    keep_padded = np.concatenate([[False], keep_mask, [False]])\n\n    
keep_diff = np.diff(keep_padded.astype(int))\n    keep_true_starts = np.where(keep_diff == 1)[0]  # +1 transitions --> going from filter out to keep\n    keep_true_ends = np.where(keep_diff == -1)[0]  # -1 transitions --> going from keep to filter out\n\n    # Find which steps correspond to non-idle segments of length at least min_non_idle_len\n    true_segment_masks = (keep_true_ends - keep_true_starts) >= min_non_idle_len\n    keep_true_starts = keep_true_starts[true_segment_masks]\n    keep_true_ends = keep_true_ends[true_segment_masks]\n\n    # Add mapping from episode unique ID key to list of non-idle ranges to keep\n    keep_ranges_map[key] = []\n    for start, end in zip(keep_true_starts, keep_true_ends, strict=True):\n        keep_ranges_map[key].append((int(start), int(end) - filter_last_n_in_ranges))\n\n    if ep_idx % 1000 == 0:\n        with Path(keep_ranges_path).open(\"w\") as f:\n            json.dump(keep_ranges_map, f)\n\nprint(\"Done!\")\nwith Path(keep_ranges_path).open(\"w\") as f:\n    json.dump(keep_ranges_map, f)\n"
  },
  {
    "path": "examples/droid/convert_droid_data_to_lerobot.py",
    "content": "\"\"\"\nMinimal example script for converting a dataset collected on the DROID platform to LeRobot format.\n\nUsage:\nuv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data\n\nIf you want to push your dataset to the Hugging Face Hub, you can use the following command:\nuv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub\n\nThe resulting dataset will get saved to the $LEROBOT_HOME directory.\n\"\"\"\n\nfrom collections import defaultdict\nimport copy\nimport glob\nimport json\nfrom pathlib import Path\nimport shutil\n\nimport cv2\nimport h5py\nfrom lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME\nfrom lerobot.common.datasets.lerobot_dataset import LeRobotDataset\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqdm\nimport tyro\n\nREPO_NAME = \"your_hf_username/my_droid_dataset\"  # Name of the output dataset, also used for the Hugging Face Hub\n\n\ndef resize_image(image, size):\n    image = Image.fromarray(image)\n    return np.array(image.resize(size, resample=Image.BICUBIC))\n\n\ndef main(data_dir: str, *, push_to_hub: bool = False):\n    # Clean up any existing dataset in the output directory\n    output_path = HF_LEROBOT_HOME / REPO_NAME\n    if output_path.exists():\n        shutil.rmtree(output_path)\n    data_dir = Path(data_dir)\n\n    # Create LeRobot dataset, define features to store\n    # We will follow the DROID data naming conventions here.\n    # LeRobot assumes that dtype of image data is `image`\n    dataset = LeRobotDataset.create(\n        repo_id=REPO_NAME,\n        robot_type=\"panda\",\n        fps=15,  # DROID data is typically recorded at 15fps\n        features={\n            # We call this \"left\" since we will only use the left stereo camera (following DROID RLDS convention)\n            \"exterior_image_1_left\": {\n                \"dtype\": \"image\",\n                \"shape\": (180, 320, 3),  # This is the 
resolution used in the DROID RLDS dataset\n                \"names\": [\"height\", \"width\", \"channel\"],\n            },\n            \"exterior_image_2_left\": {\n                \"dtype\": \"image\",\n                \"shape\": (180, 320, 3),\n                \"names\": [\"height\", \"width\", \"channel\"],\n            },\n            \"wrist_image_left\": {\n                \"dtype\": \"image\",\n                \"shape\": (180, 320, 3),\n                \"names\": [\"height\", \"width\", \"channel\"],\n            },\n            \"joint_position\": {\n                \"dtype\": \"float32\",\n                \"shape\": (7,),\n                \"names\": [\"joint_position\"],\n            },\n            \"gripper_position\": {\n                \"dtype\": \"float32\",\n                \"shape\": (1,),\n                \"names\": [\"gripper_position\"],\n            },\n            \"actions\": {\n                \"dtype\": \"float32\",\n                \"shape\": (8,),  # We will use joint *velocity* actions here (7D) + gripper position (1D)\n                \"names\": [\"actions\"],\n            },\n        },\n        image_writer_threads=10,\n        image_writer_processes=5,\n    )\n\n    # Load language annotations\n    # Note: we load the DROID language annotations for this example, but you can manually define them for your own data\n    with (data_dir / \"aggregated-annotations-030724.json\").open() as f:\n        language_annotations = json.load(f)\n\n    # Loop over raw DROID fine-tuning datasets and write episodes to the LeRobot dataset\n    # We assume the following directory structure:\n    # RAW_DROID_PATH/\n    #   - <...>/\n    #     - recordings/\n    #        - MP4/\n    #          - <camera_id>.mp4  # single-view video of left stereo pair camera\n    #     - trajectory.hdf5\n    #   - <...>/\n    episode_paths = list(data_dir.glob(\"**/trajectory.h5\"))\n    print(f\"Found {len(episode_paths)} episodes for conversion\")\n\n    # We will loop 
over each dataset_name and write episodes to the LeRobot dataset\n    for episode_path in tqdm(episode_paths, desc=\"Converting episodes\"):\n        # Load raw data\n        recording_folderpath = episode_path.parent / \"recordings\" / \"MP4\"\n        trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))\n\n        # To load the language instruction, we need to parse out the episode_id from the metadata file\n        # Again, you can modify this step for your own data, to load your own language instructions\n        metadata_filepath = next(iter(episode_path.parent.glob(\"metadata_*.json\")))\n        episode_id = metadata_filepath.name.split(\".\")[0].split(\"_\")[-1]\n        language_instruction = language_annotations.get(episode_id, {\"language_instruction1\": \"Do something\"})[\n            \"language_instruction1\"\n        ]\n        print(f\"Converting episode with language instruction: {language_instruction}\")\n\n        # Write to LeRobot dataset\n        for step in trajectory:\n            camera_type_dict = step[\"observation\"][\"camera_type\"]\n            wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]\n            exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]\n            dataset.add_frame(\n                {\n                    # Note: need to flip BGR --> RGB for loaded images\n                    \"exterior_image_1_left\": resize_image(\n                        step[\"observation\"][\"image\"][exterior_ids[0]][..., ::-1], (320, 180)\n                    ),\n                    \"exterior_image_2_left\": resize_image(\n                        step[\"observation\"][\"image\"][exterior_ids[1]][..., ::-1], (320, 180)\n                    ),\n                    \"wrist_image_left\": resize_image(step[\"observation\"][\"image\"][wrist_ids[0]][..., ::-1], (320, 180)),\n                    \"joint_position\": np.asarray(\n                        
step[\"observation\"][\"robot_state\"][\"joint_positions\"], dtype=np.float32\n                    ),\n                    \"gripper_position\": np.asarray(\n                        step[\"observation\"][\"robot_state\"][\"gripper_position\"][None], dtype=np.float32\n                    ),\n                    # Important: we use joint velocity actions here since pi05-droid was pre-trained on joint velocity actions\n                    \"actions\": np.concatenate(\n                        [step[\"action\"][\"joint_velocity\"], step[\"action\"][\"gripper_position\"][None]], dtype=np.float32\n                    ),\n                    \"task\": language_instruction,\n                }\n            )\n        dataset.save_episode()\n\n    # Optionally push to the Hugging Face Hub\n    if push_to_hub:\n        dataset.push_to_hub(\n            tags=[\"libero\", \"panda\", \"rlds\"],\n            private=False,\n            push_videos=True,\n            license=\"apache-2.0\",\n        )\n\n\n##########################################################################################################\n################ The rest of this file are functions to parse the raw DROID data #########################\n################ You don't need to worry about understanding this part           #########################\n################ It was copied from here: https://github.com/JonathanYang0127/r2d2_rlds_dataset_builder/blob/parallel_convert/r2_d2/r2_d2.py\n##########################################################################################################\n\n\ncamera_type_dict = {\n    \"hand_camera_id\": 0,\n    \"varied_camera_1_id\": 1,\n    \"varied_camera_2_id\": 1,\n}\n\ncamera_type_to_string_dict = {\n    0: \"hand_camera\",\n    1: \"varied_camera\",\n    2: \"fixed_camera\",\n}\n\n\ndef get_camera_type(cam_id):\n    if cam_id not in camera_type_dict:\n        return None\n    type_int = camera_type_dict[cam_id]\n    return 
camera_type_to_string_dict[type_int]\n\n\nclass MP4Reader:\n    def __init__(self, filepath, serial_number):\n        # Save Parameters #\n        self.serial_number = serial_number\n        self._index = 0\n\n        # Open Video Reader #\n        self._mp4_reader = cv2.VideoCapture(filepath)\n        if not self._mp4_reader.isOpened():\n            raise RuntimeError(\"Corrupted MP4 File\")\n\n    def set_reading_parameters(\n        self,\n        image=True,  # noqa: FBT002\n        concatenate_images=False,  # noqa: FBT002\n        resolution=(0, 0),\n        resize_func=None,\n    ):\n        # Save Parameters #\n        self.image = image\n        self.concatenate_images = concatenate_images\n        self.resolution = resolution\n        self.resize_func = cv2.resize\n        self.skip_reading = not image\n        if self.skip_reading:\n            return\n\n    def get_frame_resolution(self):\n        width = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)\n        height = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT)\n        return (width, height)\n\n    def get_frame_count(self):\n        if self.skip_reading:\n            return 0\n        return int(self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_COUNT))\n\n    def set_frame_index(self, index):\n        if self.skip_reading:\n            return\n\n        if index < self._index:\n            self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1)\n            self._index = index\n\n        while self._index < index:\n            self.read_camera(ignore_data=True)\n\n    def _process_frame(self, frame):\n        frame = copy.deepcopy(frame)\n        if self.resolution == (0, 0):\n            return frame\n        return self.resize_func(frame, self.resolution)\n\n    def read_camera(self, ignore_data=False, correct_timestamp=None):  # noqa: FBT002\n        # Skip if Read Unnecessary #\n        if self.skip_reading:\n            return {}\n\n        # Read Camera #\n        success, frame = 
self._mp4_reader.read()\n\n        self._index += 1\n        if not success:\n            return None\n        if ignore_data:\n            return None\n\n        # Return Data #\n        data_dict = {}\n\n        if self.concatenate_images or \"stereo\" not in self.serial_number:\n            data_dict[\"image\"] = {self.serial_number: self._process_frame(frame)}\n        else:\n            single_width = frame.shape[1] // 2\n            data_dict[\"image\"] = {\n                self.serial_number + \"_left\": self._process_frame(frame[:, :single_width, :]),\n                self.serial_number + \"_right\": self._process_frame(frame[:, single_width:, :]),\n            }\n\n        return data_dict\n\n    def disable_camera(self):\n        if hasattr(self, \"_mp4_reader\"):\n            self._mp4_reader.release()\n\n\nclass RecordedMultiCameraWrapper:\n    def __init__(self, recording_folderpath, camera_kwargs={}):  # noqa: B006\n        # Save Camera Info #\n        self.camera_kwargs = camera_kwargs\n\n        # Open Camera Readers #\n        mp4_filepaths = glob.glob(recording_folderpath + \"/*.mp4\")\n        all_filepaths = mp4_filepaths\n\n        self.camera_dict = {}\n        for f in all_filepaths:\n            serial_number = f.split(\"/\")[-1][:-4]\n            cam_type = get_camera_type(serial_number)\n            camera_kwargs.get(cam_type, {})\n\n            if f.endswith(\".mp4\"):\n                Reader = MP4Reader  # noqa: N806\n            else:\n                raise ValueError\n\n            self.camera_dict[serial_number] = Reader(f, serial_number)\n\n    def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}):  # noqa: B006\n        full_obs_dict = defaultdict(dict)\n\n        # Read Cameras In Randomized Order #\n        all_cam_ids = list(self.camera_dict.keys())\n        # random.shuffle(all_cam_ids)\n\n        for cam_id in all_cam_ids:\n            if \"stereo\" in cam_id:\n                continue\n            try:\n  
              cam_type = camera_type_dict[cam_id]\n            except KeyError:\n                print(f\"{self.camera_dict} -- {camera_type_dict}\")\n                raise ValueError(f\"Camera type {cam_id} not found in camera_type_dict\")  # noqa: B904\n            curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})\n            self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs)\n\n            timestamp = timestamp_dict.get(cam_id + \"_frame_received\", None)\n            if index is not None:\n                self.camera_dict[cam_id].set_frame_index(index)\n\n            data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp)\n\n            # Process Returned Data #\n            if data_dict is None:\n                return None\n            for key in data_dict:\n                full_obs_dict[key].update(data_dict[key])\n\n        return full_obs_dict\n\n\ndef get_hdf5_length(hdf5_file, keys_to_ignore=[]):  # noqa: B006\n    length = None\n\n    for key in hdf5_file:\n        if key in keys_to_ignore:\n            continue\n\n        curr_data = hdf5_file[key]\n        if isinstance(curr_data, h5py.Group):\n            curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore)\n        elif isinstance(curr_data, h5py.Dataset):\n            curr_length = len(curr_data)\n        else:\n            raise ValueError\n\n        if length is None:\n            length = curr_length\n        assert curr_length == length\n\n    return length\n\n\ndef load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]):  # noqa: B006\n    data_dict = {}\n\n    for key in hdf5_file:\n        if key in keys_to_ignore:\n            continue\n\n        curr_data = hdf5_file[key]\n        if isinstance(curr_data, h5py.Group):\n            data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore)\n        elif isinstance(curr_data, h5py.Dataset):\n            data_dict[key] = curr_data[index]\n        else:\n    
        raise ValueError\n\n    return data_dict\n\n\nclass TrajectoryReader:\n    def __init__(self, filepath, read_images=True):  # noqa: FBT002\n        self._hdf5_file = h5py.File(filepath, \"r\")\n        is_video_folder = \"observations/videos\" in self._hdf5_file\n        self._read_images = read_images and is_video_folder\n        self._length = get_hdf5_length(self._hdf5_file)\n        self._video_readers = {}\n        self._index = 0\n\n    def length(self):\n        return self._length\n\n    def read_timestep(self, index=None, keys_to_ignore=[]):  # noqa: B006\n        # Make Sure We Read Within Range #\n        if index is None:\n            index = self._index\n        else:\n            assert not self._read_images\n            self._index = index\n        assert index < self._length\n\n        # Load Low Dimensional Data #\n        keys_to_ignore = [*keys_to_ignore.copy(), \"videos\"]\n        timestep = load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore)\n\n        # Increment Read Index #\n        self._index += 1\n\n        # Return Timestep #\n        return timestep\n\n    def close(self):\n        self._hdf5_file.close()\n\n\ndef load_trajectory(\n    filepath=None,\n    read_cameras=True,  # noqa: FBT002\n    recording_folderpath=None,\n    camera_kwargs={},  # noqa: B006\n    remove_skipped_steps=False,  # noqa: FBT002\n    num_samples_per_traj=None,\n    num_samples_per_traj_coeff=1.5,\n):\n    read_recording_folderpath = read_cameras and (recording_folderpath is not None)\n\n    traj_reader = TrajectoryReader(filepath)\n    if read_recording_folderpath:\n        camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)\n\n    horizon = traj_reader.length()\n    timestep_list = []\n\n    # Choose Timesteps To Save #\n    if num_samples_per_traj:\n        num_to_save = num_samples_per_traj\n        if remove_skipped_steps:\n            num_to_save = int(num_to_save * 
num_samples_per_traj_coeff)\n        max_size = min(num_to_save, horizon)\n        indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False))\n    else:\n        indices_to_save = np.arange(horizon)\n\n    # Iterate Over Trajectory #\n    for i in indices_to_save:\n        # Get HDF5 Data #\n        timestep = traj_reader.read_timestep(index=i)\n\n        # If Applicable, Get Recorded Data #\n        if read_recording_folderpath:\n            timestamp_dict = timestep[\"observation\"][\"timestamp\"][\"cameras\"]\n            camera_type_dict = {\n                k: camera_type_to_string_dict[v] for k, v in timestep[\"observation\"][\"camera_type\"].items()\n            }\n            camera_obs = camera_reader.read_cameras(\n                index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict\n            )\n            camera_failed = camera_obs is None\n\n            # Add Data To Timestep If Successful #\n            if camera_failed:\n                break\n            timestep[\"observation\"].update(camera_obs)\n\n        # Filter Steps #\n        step_skipped = not timestep[\"observation\"][\"controller_info\"].get(\"movement_enabled\", True)\n        delete_skipped_step = step_skipped and remove_skipped_steps\n\n        # Save Filtered Timesteps #\n        if delete_skipped_step:\n            del timestep\n        else:\n            timestep_list.append(timestep)\n\n    # Remove Extra Transitions #\n    timestep_list = np.array(timestep_list)\n    if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj):\n        ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False)\n        timestep_list = timestep_list[ind_to_keep]\n\n    # Close Readers #\n    traj_reader.close()\n\n    # Return Data #\n    return timestep_list\n\n\nif __name__ == \"__main__\":\n    tyro.cli(main)\n"
  },
  {
    "path": "examples/droid/main.py",
    "content": "# ruff: noqa\n\nimport contextlib\nimport dataclasses\nimport datetime\nimport faulthandler\nimport os\nimport signal\nimport time\nfrom moviepy.editor import ImageSequenceClip\nimport numpy as np\nfrom openpi_client import image_tools\nfrom openpi_client import websocket_client_policy\nimport pandas as pd\nfrom PIL import Image\nfrom droid.robot_env import RobotEnv\nimport tqdm\nimport tyro\n\nfaulthandler.enable()\n\n# DROID data collection frequency -- we slow down execution to match this frequency\nDROID_CONTROL_FREQUENCY = 15\n\n\n@dataclasses.dataclass\nclass Args:\n    # Hardware parameters\n    left_camera_id: str = \"<your_camera_id>\"  # e.g., \"24259877\"\n    right_camera_id: str = \"<your_camera_id>\"  # e.g., \"24514023\"\n    wrist_camera_id: str = \"<your_camera_id>\"  # e.g., \"13062452\"\n\n    # Policy parameters\n    external_camera: str | None = (\n        None  # which external camera should be fed to the policy, choose from [\"left\", \"right\"]\n    )\n\n    # Rollout parameters\n    max_timesteps: int = 600\n    # How many actions to execute from a predicted action chunk before querying policy server again\n    # 8 is usually a good default (equals 0.5 seconds of action execution).\n    open_loop_horizon: int = 8\n\n    # Remote server parameters\n    remote_host: str = \"0.0.0.0\"  # point this to the IP address of the policy server, e.g., \"192.168.1.100\"\n    remote_port: int = (\n        8000  # point this to the port of the policy server, default server port for openpi servers is 8000\n    )\n\n\n# We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is\n# waiting for a new action chunk, it will raise an exception and the server connection dies.\n# This context manager temporarily prevents Ctrl+C and delays it after the server call is complete.\n@contextlib.contextmanager\ndef prevent_keyboard_interrupt():\n    \"\"\"Temporarily prevent keyboard interrupts by 
delaying them until after the protected code.\"\"\"\n    interrupted = False\n    original_handler = signal.getsignal(signal.SIGINT)\n\n    def handler(signum, frame):\n        nonlocal interrupted\n        interrupted = True\n\n    signal.signal(signal.SIGINT, handler)\n    try:\n        yield\n    finally:\n        signal.signal(signal.SIGINT, original_handler)\n        if interrupted:\n            raise KeyboardInterrupt\n\n\ndef main(args: Args):\n    # Make sure external camera is specified by user -- we only use one external camera for the policy\n    assert (\n        args.external_camera is not None and args.external_camera in [\"left\", \"right\"]\n    ), f\"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}\"\n\n    # Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.\n    env = RobotEnv(action_space=\"joint_velocity\", gripper_action_space=\"position\")\n    print(\"Created the droid env!\")\n\n    # Connect to the policy server\n    policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)\n\n    df = pd.DataFrame(columns=[\"success\", \"duration\", \"video_filename\"])\n\n    while True:\n        instruction = input(\"Enter instruction: \")\n\n        # Rollout parameters\n        actions_from_chunk_completed = 0\n        pred_action_chunk = None\n\n        # Prepare to save video of rollout\n        timestamp = datetime.datetime.now().strftime(\"%Y_%m_%d_%H:%M:%S\")\n        video = []\n        bar = tqdm.tqdm(range(args.max_timesteps))\n        print(\"Running rollout... 
press Ctrl+C to stop early.\")\n        for t_step in bar:\n            start_time = time.time()\n            try:\n                # Get the current observation\n                curr_obs = _extract_observation(\n                    args,\n                    env.get_observation(),\n                    # Save the first observation to disk\n                    save_to_disk=t_step == 0,\n                )\n\n                video.append(curr_obs[f\"{args.external_camera}_image\"])\n\n                # Send websocket request to policy server if it's time to predict a new chunk\n                if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:\n                    actions_from_chunk_completed = 0\n\n                    # We resize images on the robot laptop to minimize the amount of data sent to the policy server\n                    # and improve latency.\n                    request_data = {\n                        \"observation/exterior_image_1_left\": image_tools.resize_with_pad(\n                            curr_obs[f\"{args.external_camera}_image\"], 224, 224\n                        ),\n                        \"observation/wrist_image_left\": image_tools.resize_with_pad(curr_obs[\"wrist_image\"], 224, 224),\n                        \"observation/joint_position\": curr_obs[\"joint_position\"],\n                        \"observation/gripper_position\": curr_obs[\"gripper_position\"],\n                        \"prompt\": instruction,\n                    }\n\n                    # Wrap the server call in a context manager to prevent Ctrl+C from interrupting it\n                    # Ctrl+C will be handled after the server call is complete\n                    with prevent_keyboard_interrupt():\n                        # this returns action chunk [10, 8] of 10 joint velocity actions (7) + gripper position (1)\n                        pred_action_chunk = policy_client.infer(request_data)[\"actions\"]\n                   
 assert pred_action_chunk.shape == (10, 8)\n\n                # Select current action to execute from chunk\n                action = pred_action_chunk[actions_from_chunk_completed]\n                actions_from_chunk_completed += 1\n\n                # Binarize gripper action\n                if action[-1].item() > 0.5:\n                    # action[-1] = 1.0\n                    action = np.concatenate([action[:-1], np.ones((1,))])\n                else:\n                    # action[-1] = 0.0\n                    action = np.concatenate([action[:-1], np.zeros((1,))])\n\n                # clip all dimensions of action to [-1, 1]\n                action = np.clip(action, -1, 1)\n\n                env.step(action)\n\n                # Sleep to match DROID data collection frequency\n                elapsed_time = time.time() - start_time\n                if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:\n                    time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)\n            except KeyboardInterrupt:\n                break\n\n        video = np.stack(video)\n        save_filename = \"video_\" + timestamp\n        ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + \".mp4\", codec=\"libx264\")\n\n        success: str | float | None = None\n        while not isinstance(success, float):\n            success = input(\n                \"Did the rollout succeed? 
(enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec\"\n            )\n            if success == \"y\":\n                success = 1.0\n            elif success == \"n\":\n                success = 0.0\n\n            success = float(success) / 100\n            if not (0 <= success <= 1):\n                print(f\"Success must be a number in [0, 100] but got: {success * 100}\")\n\n        df = df.append(\n            {\n                \"success\": success,\n                \"duration\": t_step,\n                \"video_filename\": save_filename,\n            },\n            ignore_index=True,\n        )\n\n        if input(\"Do one more eval? (enter y or n) \").lower() != \"y\":\n            break\n        env.reset()\n\n    os.makedirs(\"results\", exist_ok=True)\n    timestamp = datetime.datetime.now().strftime(\"%I:%M%p_%B_%d_%Y\")\n    csv_filename = os.path.join(\"results\", f\"eval_{timestamp}.csv\")\n    df.to_csv(csv_filename)\n    print(f\"Results saved to {csv_filename}\")\n\n\ndef _extract_observation(args: Args, obs_dict, *, save_to_disk=False):\n    image_observations = obs_dict[\"image\"]\n    left_image, right_image, wrist_image = None, None, None\n    for key in image_observations:\n        # Note the \"left\" below refers to the left camera in the stereo pair.\n        # The model is only trained on left stereo cams, so we only feed those.\n        if args.left_camera_id in key and \"left\" in key:\n            left_image = image_observations[key]\n        elif args.right_camera_id in key and \"left\" in key:\n            right_image = image_observations[key]\n        elif args.wrist_camera_id in key and \"left\" in key:\n            wrist_image = image_observations[key]\n\n    # Drop the alpha dimension\n    left_image = left_image[..., :3]\n    right_image = right_image[..., :3]\n    wrist_image = wrist_image[..., :3]\n\n    # Convert to RGB\n    left_image = left_image[..., ::-1]\n    right_image = right_image[..., 
::-1]\n    wrist_image = wrist_image[..., ::-1]\n\n    # In addition to image observations, also capture the proprioceptive state\n    robot_state = obs_dict[\"robot_state\"]\n    cartesian_position = np.array(robot_state[\"cartesian_position\"])\n    joint_position = np.array(robot_state[\"joint_positions\"])\n    gripper_position = np.array([robot_state[\"gripper_position\"]])\n\n    # Save the images to disk so that they can be viewed live while the robot is running\n    # Create one combined image to make live viewing easy\n    if save_to_disk:\n        combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)\n        combined_image = Image.fromarray(combined_image)\n        combined_image.save(\"robot_camera_views.png\")\n\n    return {\n        \"left_image\": left_image,\n        \"right_image\": right_image,\n        \"wrist_image\": wrist_image,\n        \"cartesian_position\": cartesian_position,\n        \"joint_position\": joint_position,\n        \"gripper_position\": gripper_position,\n    }\n\n\nif __name__ == \"__main__\":\n    args: Args = tyro.cli(Args)\n    main(args)\n"
  },
  {
    "path": "examples/inference.ipynb",
    "content": "{\n    \"cells\": [\n        {\n            \"cell_type\": \"code\",\n            \"execution_count\": 1,\n            \"metadata\": {},\n            \"outputs\": [],\n            \"source\": [\n                \"import dataclasses\\n\",\n                \"\\n\",\n                \"import jax\\n\",\n                \"\\n\",\n                \"from openpi.models import model as _model\\n\",\n                \"from openpi.policies import droid_policy\\n\",\n                \"from openpi.policies import policy_config as _policy_config\\n\",\n                \"from openpi.shared import download\\n\",\n                \"from openpi.training import config as _config\\n\",\n                \"from openpi.training import data_loader as _data_loader\"\n            ]\n        },\n        {\n            \"cell_type\": \"markdown\",\n            \"metadata\": {},\n            \"source\": [\n                \"# Policy inference\\n\",\n                \"\\n\",\n                \"The following example shows how to create a policy from a checkpoint and run inference on a dummy example.\"\n            ]\n        },\n        {\n            \"cell_type\": \"code\",\n            \"execution_count\": null,\n            \"metadata\": {},\n            \"outputs\": [],\n            \"source\": [\n                \"config = _config.get_config(\\\"pi0_fast_droid\\\")\\n\",\n                \"checkpoint_dir = download.maybe_download(\\\"gs://openpi-assets/checkpoints/pi0_fast_droid\\\")\\n\",\n                \"\\n\",\n                \"# Create a trained policy.\\n\",\n                \"policy = _policy_config.create_trained_policy(config, checkpoint_dir)\\n\",\n                \"\\n\",\n                \"# Run inference on a dummy example. 
This example corresponds to observations produced by the DROID runtime.\\n\",\n                \"example = droid_policy.make_droid_example()\\n\",\n                \"result = policy.infer(example)\\n\",\n                \"\\n\",\n                \"# Delete the policy to free up memory.\\n\",\n                \"del policy\\n\",\n                \"\\n\",\n                \"print(\\\"Actions shape:\\\", result[\\\"actions\\\"].shape)\"\n            ]\n        },\n        {\n            \"cell_type\": \"markdown\",\n            \"metadata\": {},\n            \"source\": [\n                \"# Working with a live model\\n\",\n                \"\\n\",\n                \"\\n\",\n                \"The following example shows how to create a live model from a checkpoint and compute the training loss. First, we are going to demonstrate how to do it with fake data.\\n\"\n            ]\n        },\n        {\n            \"cell_type\": \"code\",\n            \"execution_count\": null,\n            \"metadata\": {},\n            \"outputs\": [],\n            \"source\": [\n                \"config = _config.get_config(\\\"pi0_aloha_sim\\\")\\n\",\n                \"\\n\",\n                \"checkpoint_dir = download.maybe_download(\\\"gs://openpi-assets/checkpoints/pi0_aloha_sim\\\")\\n\",\n                \"key = jax.random.key(0)\\n\",\n                \"\\n\",\n                \"# Create a model from the checkpoint.\\n\",\n                \"model = config.model.load(_model.restore_params(checkpoint_dir / \\\"params\\\"))\\n\",\n                \"\\n\",\n                \"# We can create fake observations and actions to test the model.\\n\",\n                \"obs, act = config.model.fake_obs(), config.model.fake_act()\\n\",\n                \"\\n\",\n                \"# Compute the training loss on the fake batch.\\n\",\n                \"loss = model.compute_loss(key, obs, act)\\n\",\n                \"print(\\\"Loss shape:\\\", loss.shape)\"\n            ]\n        },\n        {\n
            \"cell_type\": \"markdown\",\n            \"metadata\": {},\n            \"source\": [\n                \"Now, we are going to create a data loader and use a real batch of training data to compute the loss.\"\n            ]\n        },\n        {\n            \"cell_type\": \"code\",\n            \"execution_count\": null,\n            \"metadata\": {},\n            \"outputs\": [],\n            \"source\": [\n                \"# Reduce the batch size to reduce memory usage.\\n\",\n                \"config = dataclasses.replace(config, batch_size=2)\\n\",\n                \"\\n\",\n                \"# Load a single batch of data. This is the same data that will be used during training.\\n\",\n                \"# NOTE: In order to make this example self-contained, we are skipping the normalization step\\n\",\n                \"# since it requires the normalization statistics to be generated using `compute_norm_stats`.\\n\",\n                \"loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\\n\",\n                \"obs, act = next(iter(loader))\\n\",\n                \"\\n\",\n                \"# Compute the training loss on the real batch.\\n\",\n                \"loss = model.compute_loss(key, obs, act)\\n\",\n                \"\\n\",\n                \"# Delete the model to free up memory.\\n\",\n                \"del model\\n\",\n                \"\\n\",\n                \"print(\\\"Loss shape:\\\", loss.shape)\"\n            ]\n        }\n    ],\n    \"metadata\": {\n        \"kernelspec\": {\n            \"display_name\": \".venv\",\n            \"language\": \"python\",\n            \"name\": \"python3\"\n        },\n        \"language_info\": {\n            \"codemirror_mode\": {\n                \"name\": \"ipython\",\n                \"version\": 3\n            },\n            \"file_extension\": \".py\",\n            \"mimetype\": \"text/x-python\",\n            \"name\": \"python\",\n            \"nbconvert_exporter\":
 \"python\",\n            \"pygments_lexer\": \"ipython3\",\n            \"version\": \"3.11.9\"\n        }\n    },\n    \"nbformat\": 4,\n    \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/libero/Dockerfile",
    "content": "# Dockerfile for the LIBERO benchmark.\n\n# Build the container:\n# docker build . -t libero -f examples/libero/Dockerfile\n\n# Run the container:\n# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash\n\nFROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0\nCOPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/\n\nRUN apt-get update && \\\n    apt-get install -y \\\n    make \\\n    g++ \\\n    clang \\\n    libosmesa6-dev \\\n    libgl1-mesa-glx \\\n    libegl1 \\\n    libglew-dev \\\n    libglfw3-dev \\\n    libgles2-mesa-dev \\\n    libglib2.0-0 \\\n    libsm6 \\\n    libxrender1 \\\n    libxext6\n\nWORKDIR /app\n\n# Copy from the cache instead of linking since it's a mounted volume\nENV UV_LINK_MODE=copy\n\n# Write the virtual environment outside of the project directory so it doesn't\n# leak out of the container when we mount the application code.\nENV UV_PROJECT_ENVIRONMENT=/.venv\n\n# Copy the requirements files so we can install dependencies.\n# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.\n# This strategy is best for development-style usage.\nCOPY ./examples/libero/requirements.txt /tmp/requirements.txt\nCOPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt\nCOPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml\n\n# Install python dependencies.\nRUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT\nRUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match\nENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero\n\n# Create a default config file to avoid an input prompt from LIBERO's init script.\n# 
https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py\nENV LIBERO_CONFIG_PATH=/tmp/libero\nRUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml\nbenchmark_root: /app/third_party/libero/libero/libero\nbddl_files: /app/third_party/libero/libero/libero/bddl_files\ninit_states: /app/third_party/libero/libero/libero/init_files\ndatasets: /app/third_party/libero/libero/datasets\nassets: /app/third_party/libero/libero/libero/assets\nEOF\n\nRUN mkdir -p /usr/share/glvnd/egl_vendor.d && echo '{\"file_format_version\" : \"1.0.0\", \"ICD\" : { \"library_path\" : \"libEGL_nvidia.so.0\" }}' > /usr/share/glvnd/egl_vendor.d/10_nvidia.json\n\nCMD [\"/bin/bash\", \"-c\", \"source /.venv/bin/activate && python examples/libero/main.py $CLIENT_ARGS\"]\n"
  },
  {
    "path": "examples/libero/README.md",
    "content": "# LIBERO Benchmark\n\nThis example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO\n\nNote: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command.\n\nThis example requires git submodules to be initialized. Don't forget to run:\n\n```bash\ngit submodule update --init --recursive\n```\n\n## With Docker (recommended)\n\n```bash\n# Grant access to the X11 server:\nsudo xhost +local:docker\n\n# To run with the default checkpoint and task suite:\nSERVER_ARGS=\"--env LIBERO\" docker compose -f examples/libero/compose.yml up --build\n\n# To run with glx for Mujoco instead (use this if you have egl errors):\nMUJOCO_GL=glx SERVER_ARGS=\"--env LIBERO\" docker compose -f examples/libero/compose.yml up --build\n```\n\nYou can customize the loaded checkpoint by providing additional `SERVER_ARGS` (see `scripts/serve_policy.py`), and the LIBERO task suite by providing additional `CLIENT_ARGS` (see `examples/libero/main.py`).\nFor example:\n\n```bash\n# To load a custom checkpoint (located in the top-level openpi/ directory):\nexport SERVER_ARGS=\"--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir ./my_custom_checkpoint\"\n\n# To run the libero_10 task suite:\nexport CLIENT_ARGS=\"--args.task-suite-name libero_10\"\n```\n\n## Without Docker (not recommended)\n\nTerminal window 1:\n\n```bash\n# Create virtual environment\nuv venv --python 3.8 examples/libero/.venv\nsource examples/libero/.venv/bin/activate\nuv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match\nuv pip install -e packages/openpi-client\nuv pip install -e third_party/libero\nexport PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero\n\n# Run the simulation\npython examples/libero/main.py\n\n# To run with glx 
for Mujoco instead (use this if you have egl errors):\nMUJOCO_GL=glx python examples/libero/main.py\n```\n\nTerminal window 2:\n\n```bash\n# Run the server\nuv run scripts/serve_policy.py --env LIBERO\n```\n\n## Results\n\nIf you want to reproduce the following numbers, you can evaluate the checkpoint at `gs://openpi-assets/checkpoints/pi05_libero/`. This\ncheckpoint was trained in openpi with the `pi05_libero` config.\n\n| Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |\n|-------|---------------|---------------|-------------|-----------|---------|\n| π0.5 @ 30k (finetuned) | 98.8 | 98.2 | 98.0 | 92.4 | 96.85 |\n"
  },
  {
    "path": "examples/libero/compose.yml",
    "content": "# Run with:\n# docker compose -f examples/libero/compose.yml up --build\nservices:\n  runtime:\n    image: libero\n    depends_on:\n      - openpi_server\n    build:\n      context: ../..\n      dockerfile: examples/libero/Dockerfile\n    init: true\n    tty: true\n    network_mode: host\n    privileged: true\n    volumes:\n      - $PWD:/app\n      - ../../data:/data\n      - /tmp/.X11-unix:/tmp/.X11-unix:ro\n    environment:\n      - CLIENT_ARGS\n      - DISPLAY=$DISPLAY\n      - MUJOCO_GL=${MUJOCO_GL:-egl}\n      - MUJOCO_EGL_DEVICE_ID=0\n      - NVIDIA_DRIVER_CAPABILITIES=all\n      - PYOPENGL_PLATFORM=egl\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: 1\n              capabilities: [gpu]\n\n  openpi_server:\n    image: openpi_server\n    build:\n      context: ../..\n      dockerfile: scripts/docker/serve_policy.Dockerfile\n    init: true\n    tty: true\n    network_mode: host\n    volumes:\n      - $PWD:/app\n      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets\n    environment:\n      - SERVER_ARGS\n      - OPENPI_DATA_HOME=/openpi_assets\n      - IS_DOCKER=true\n\n    # Comment out this block if not running on a machine with GPUs.\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: 1\n              capabilities: [gpu]\n"
  },
  {
    "path": "examples/libero/convert_libero_data_to_lerobot.py",
    "content": "\"\"\"\nMinimal example script for converting a dataset to LeRobot format.\n\nWe use the Libero dataset (stored in RLDS) for this example, but it can be easily\nmodified for any other data you have saved in a custom format.\n\nUsage:\nuv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data\n\nIf you want to push your dataset to the Hugging Face Hub, you can use the following command:\nuv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub\n\nNote: to run the script, you need to install tensorflow_datasets:\n`uv pip install tensorflow tensorflow_datasets`\n\nYou can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds\nThe resulting dataset will get saved to the $HF_LEROBOT_HOME directory.\nRunning this conversion script will take approximately 30 minutes.\n\"\"\"\n\nimport shutil\n\nfrom lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME\nfrom lerobot.common.datasets.lerobot_dataset import LeRobotDataset\nimport tensorflow_datasets as tfds\nimport tyro\n\nREPO_NAME = \"your_hf_username/libero\"  # Name of the output dataset, also used for the Hugging Face Hub\nRAW_DATASET_NAMES = [\n    \"libero_10_no_noops\",\n    \"libero_goal_no_noops\",\n    \"libero_object_no_noops\",\n    \"libero_spatial_no_noops\",\n]  # For simplicity we will combine multiple Libero datasets into one training dataset\n\n\ndef main(data_dir: str, *, push_to_hub: bool = False):\n    # Clean up any existing dataset in the output directory\n    output_path = HF_LEROBOT_HOME / REPO_NAME\n    if output_path.exists():\n        shutil.rmtree(output_path)\n\n    # Create LeRobot dataset, define features to store\n    # OpenPi assumes that proprio is stored in `state` and actions in `action`\n    # LeRobot assumes that dtype of image data is `image`\n    dataset = LeRobotDataset.create(\n        repo_id=REPO_NAME,\n        robot_type=\"panda\",\n      
  fps=10,\n        features={\n            \"image\": {\n                \"dtype\": \"image\",\n                \"shape\": (256, 256, 3),\n                \"names\": [\"height\", \"width\", \"channel\"],\n            },\n            \"wrist_image\": {\n                \"dtype\": \"image\",\n                \"shape\": (256, 256, 3),\n                \"names\": [\"height\", \"width\", \"channel\"],\n            },\n            \"state\": {\n                \"dtype\": \"float32\",\n                \"shape\": (8,),\n                \"names\": [\"state\"],\n            },\n            \"actions\": {\n                \"dtype\": \"float32\",\n                \"shape\": (7,),\n                \"names\": [\"actions\"],\n            },\n        },\n        image_writer_threads=10,\n        image_writer_processes=5,\n    )\n\n    # Loop over raw Libero datasets and write episodes to the LeRobot dataset\n    # You can modify this for your own data format\n    for raw_dataset_name in RAW_DATASET_NAMES:\n        raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split=\"train\")\n        for episode in raw_dataset:\n            for step in episode[\"steps\"].as_numpy_iterator():\n                dataset.add_frame(\n                    {\n                        \"image\": step[\"observation\"][\"image\"],\n                        \"wrist_image\": step[\"observation\"][\"wrist_image\"],\n                        \"state\": step[\"observation\"][\"state\"],\n                        \"actions\": step[\"action\"],\n                        \"task\": step[\"language_instruction\"].decode(),\n                    }\n                )\n            dataset.save_episode()\n\n    # Optionally push to the Hugging Face Hub\n    if push_to_hub:\n        dataset.push_to_hub(\n            tags=[\"libero\", \"panda\", \"rlds\"],\n            private=False,\n            push_videos=True,\n            license=\"apache-2.0\",\n        )\n\n\nif __name__ == \"__main__\":\n    
tyro.cli(main)\n"
  },
  {
    "path": "examples/libero/main.py",
    "content": "import collections\nimport dataclasses\nimport logging\nimport math\nimport pathlib\n\nimport imageio\nfrom libero.libero import benchmark\nfrom libero.libero import get_libero_path\nfrom libero.libero.envs import OffScreenRenderEnv\nimport numpy as np\nfrom openpi_client import image_tools\nfrom openpi_client import websocket_client_policy as _websocket_client_policy\nimport tqdm\nimport tyro\n\nLIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]\nLIBERO_ENV_RESOLUTION = 256  # resolution used to render training data\n\n\n@dataclasses.dataclass\nclass Args:\n    #################################################################################################################\n    # Model server parameters\n    #################################################################################################################\n    host: str = \"0.0.0.0\"\n    port: int = 8000\n    resize_size: int = 224\n    replan_steps: int = 5\n\n    #################################################################################################################\n    # LIBERO environment-specific parameters\n    #################################################################################################################\n    task_suite_name: str = (\n        \"libero_spatial\"  # Task suite. 
Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90\n    )\n    num_steps_wait: int = 10  # Number of steps to wait for objects to stabilize i n sim\n    num_trials_per_task: int = 50  # Number of rollouts per task\n\n    #################################################################################################################\n    # Utils\n    #################################################################################################################\n    video_out_path: str = \"data/libero/videos\"  # Path to save videos\n\n    seed: int = 7  # Random Seed (for reproducibility)\n\n\ndef eval_libero(args: Args) -> None:\n    # Set random seed\n    np.random.seed(args.seed)\n\n    # Initialize LIBERO task suite\n    benchmark_dict = benchmark.get_benchmark_dict()\n    task_suite = benchmark_dict[args.task_suite_name]()\n    num_tasks_in_suite = task_suite.n_tasks\n    logging.info(f\"Task suite: {args.task_suite_name}\")\n\n    pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)\n\n    if args.task_suite_name == \"libero_spatial\":\n        max_steps = 220  # longest training demo has 193 steps\n    elif args.task_suite_name == \"libero_object\":\n        max_steps = 280  # longest training demo has 254 steps\n    elif args.task_suite_name == \"libero_goal\":\n        max_steps = 300  # longest training demo has 270 steps\n    elif args.task_suite_name == \"libero_10\":\n        max_steps = 520  # longest training demo has 505 steps\n    elif args.task_suite_name == \"libero_90\":\n        max_steps = 400  # longest training demo has 373 steps\n    else:\n        raise ValueError(f\"Unknown task suite: {args.task_suite_name}\")\n\n    client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)\n\n    # Start evaluation\n    total_episodes, total_successes = 0, 0\n    for task_id in tqdm.tqdm(range(num_tasks_in_suite)):\n        # Get task\n        task = task_suite.get_task(task_id)\n\n        # 
Get default LIBERO initial states\n        initial_states = task_suite.get_task_init_states(task_id)\n\n        # Initialize LIBERO environment and task description\n        env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)\n\n        # Start episodes\n        task_episodes, task_successes = 0, 0\n        for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):\n            logging.info(f\"\\nTask: {task_description}\")\n\n            # Reset environment\n            env.reset()\n            action_plan = collections.deque()\n\n            # Set initial states\n            obs = env.set_init_state(initial_states[episode_idx])\n\n            # Setup\n            t = 0\n            replay_images = []\n\n            logging.info(f\"Starting episode {task_episodes+1}...\")\n            while t < max_steps + args.num_steps_wait:\n                try:\n                    # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects\n                    # and we need to wait for them to fall\n                    if t < args.num_steps_wait:\n                        obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)\n                        t += 1\n                        continue\n\n                    # Get preprocessed image\n                    # IMPORTANT: rotate 180 degrees to match train preprocessing\n                    img = np.ascontiguousarray(obs[\"agentview_image\"][::-1, ::-1])\n                    wrist_img = np.ascontiguousarray(obs[\"robot0_eye_in_hand_image\"][::-1, ::-1])\n                    img = image_tools.convert_to_uint8(\n                        image_tools.resize_with_pad(img, args.resize_size, args.resize_size)\n                    )\n                    wrist_img = image_tools.convert_to_uint8(\n                        image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)\n                    )\n\n                    # Save preprocessed image for replay 
video\n                    replay_images.append(img)\n\n                    if not action_plan:\n                        # Finished executing previous action chunk -- compute new chunk\n                        # Prepare observations dict\n                        element = {\n                            \"observation/image\": img,\n                            \"observation/wrist_image\": wrist_img,\n                            \"observation/state\": np.concatenate(\n                                (\n                                    obs[\"robot0_eef_pos\"],\n                                    _quat2axisangle(obs[\"robot0_eef_quat\"]),\n                                    obs[\"robot0_gripper_qpos\"],\n                                )\n                            ),\n                            \"prompt\": str(task_description),\n                        }\n\n                        # Query model to get action\n                        action_chunk = client.infer(element)[\"actions\"]\n                        assert (\n                            len(action_chunk) >= args.replan_steps\n                        ), f\"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps.\"\n                        action_plan.extend(action_chunk[: args.replan_steps])\n\n                    action = action_plan.popleft()\n\n                    # Execute action in environment\n                    obs, reward, done, info = env.step(action.tolist())\n                    if done:\n                        task_successes += 1\n                        total_successes += 1\n                        break\n                    t += 1\n\n                except Exception as e:\n                    logging.error(f\"Caught exception: {e}\")\n                    break\n\n            task_episodes += 1\n            total_episodes += 1\n\n            # Save a replay video of the episode\n            suffix = \"success\" if done else \"failure\"\n         
   task_segment = task_description.replace(\" \", \"_\")\n            imageio.mimwrite(\n                pathlib.Path(args.video_out_path) / f\"rollout_{task_segment}_{suffix}.mp4\",\n                [np.asarray(x) for x in replay_images],\n                fps=10,\n            )\n\n            # Log current results\n            logging.info(f\"Success: {done}\")\n            logging.info(f\"# episodes completed so far: {total_episodes}\")\n            logging.info(f\"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)\")\n\n        # Log final results\n        logging.info(f\"Current task success rate: {float(task_successes) / float(task_episodes)}\")\n        logging.info(f\"Current total success rate: {float(total_successes) / float(total_episodes)}\")\n\n    logging.info(f\"Total success rate: {float(total_successes) / float(total_episodes)}\")\n    logging.info(f\"Total episodes: {total_episodes}\")\n\n\ndef _get_libero_env(task, resolution, seed):\n    \"\"\"Initializes and returns the LIBERO environment, along with the task description.\"\"\"\n    task_description = task.language\n    task_bddl_file = pathlib.Path(get_libero_path(\"bddl_files\")) / task.problem_folder / task.bddl_file\n    env_args = {\"bddl_file_name\": task_bddl_file, \"camera_heights\": resolution, \"camera_widths\": resolution}\n    env = OffScreenRenderEnv(**env_args)\n    env.seed(seed)  # IMPORTANT: seed seems to affect object positions even when using fixed initial state\n    return env, task_description\n\n\ndef _quat2axisangle(quat):\n    \"\"\"\n    Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55\n    \"\"\"\n    # clip quaternion\n    if quat[3] > 1.0:\n        quat[3] = 1.0\n    elif quat[3] < -1.0:\n        quat[3] = -1.0\n\n    den = np.sqrt(1.0 - quat[3] * quat[3])\n    if math.isclose(den, 0.0):\n        # This is (close to) a zero 
degree rotation, immediately return\n        return np.zeros(3)\n\n    return (quat[:3] * 2.0 * math.acos(quat[3])) / den\n\n\nif __name__ == \"__main__\":\n    logging.basicConfig(level=logging.INFO)\n    tyro.cli(eval_libero)\n"
  },
  {
    "path": "examples/libero/requirements.in",
    "content": "imageio[ffmpeg]\nnumpy==1.22.4\ntqdm\ntyro\nPyYaml\nopencv-python==4.6.0.66\ntorch==1.11.0+cu113\ntorchvision==0.12.0+cu113\ntorchaudio==0.11.0+cu113\nrobosuite==1.4.1\nmatplotlib==3.5.3\n"
  },
  {
    "path": "examples/libero/requirements.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match\nabsl-py==2.1.0\n    # via mujoco\ncertifi==2024.12.14\n    # via requests\ncharset-normalizer==3.4.0\n    # via requests\ncycler==0.12.1\n    # via matplotlib\ndocstring-parser==0.16\n    # via tyro\netils==1.3.0\n    # via mujoco\neval-type-backport==0.2.0\n    # via tyro\nevdev==1.7.1\n    # via pynput\nfonttools==4.55.3\n    # via matplotlib\nglfw==1.12.0\n    # via mujoco\nidna==3.10\n    # via requests\nimageio==2.35.1\n    # via -r examples/libero/requirements.in\nimageio-ffmpeg==0.5.1\n    # via imageio\nimportlib-metadata==8.5.0\n    # via typeguard\nimportlib-resources==6.4.5\n    # via etils\nkiwisolver==1.4.7\n    # via matplotlib\nllvmlite==0.36.0\n    # via numba\nmarkdown-it-py==3.0.0\n    # via rich\nmatplotlib==3.5.3\n    # via -r examples/libero/requirements.in\nmdurl==0.1.2\n    # via markdown-it-py\nmujoco==3.2.3\n    # via robosuite\nnumba==0.53.1\n    # via robosuite\nnumpy==1.22.4\n    # via\n    #   -r examples/libero/requirements.in\n    #   imageio\n    #   matplotlib\n    #   mujoco\n    #   numba\n    #   opencv-python\n    #   robosuite\n    #   scipy\n    #   torchvision\nopencv-python==4.6.0.66\n    # via\n    #   -r examples/libero/requirements.in\n    #   robosuite\npackaging==24.2\n    # via matplotlib\npillow==10.4.0\n    # via\n    #   imageio\n    #   matplotlib\n    #   robosuite\n    #   torchvision\npsutil==6.1.0\n    # via imageio\npygments==2.18.0\n    # via rich\npynput==1.7.7\n    # via robosuite\npyopengl==3.1.7\n    # via mujoco\npyparsing==3.1.4\n    # via matplotlib\npython-dateutil==2.9.0.post0\n    # via matplotlib\npython-xlib==0.33\n    # via pynput\npyyaml==6.0.2\n    # via -r examples/libero/requirements.in\nrequests==2.32.3\n    # via torchvision\nrich==13.9.4\n    # via 
tyro\nrobosuite==1.4.1\n    # via -r examples/libero/requirements.in\nscipy==1.10.1\n    # via robosuite\nsetuptools==75.3.0\n    # via\n    #   imageio-ffmpeg\n    #   numba\nshtab==1.7.1\n    # via tyro\nsix==1.17.0\n    # via\n    #   pynput\n    #   python-dateutil\n    #   python-xlib\ntermcolor==2.4.0\n    # via robosuite\ntorch==1.11.0+cu113\n    # via\n    #   -r examples/libero/requirements.in\n    #   torchaudio\n    #   torchvision\ntorchaudio==0.11.0+cu113\n    # via -r examples/libero/requirements.in\ntorchvision==0.12.0+cu113\n    # via -r examples/libero/requirements.in\ntqdm==4.67.1\n    # via -r examples/libero/requirements.in\ntypeguard==4.4.0\n    # via tyro\ntyping-extensions==4.12.2\n    # via\n    #   etils\n    #   rich\n    #   torch\n    #   torchvision\n    #   typeguard\n    #   tyro\ntyro==0.9.2\n    # via -r examples/libero/requirements.in\nurllib3==2.2.3\n    # via requests\nzipp==3.20.2\n    # via\n    #   etils\n    #   importlib-metadata\n    #   importlib-resources\n"
  },
  {
    "path": "examples/policy_records.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import pathlib\\n\",\n    \"\\n\",\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"record_path = pathlib.Path(\\\"../policy_records\\\")\\n\",\n    \"num_steps = len(list(record_path.glob(\\\"step_*.npy\\\")))\\n\",\n    \"\\n\",\n    \"records = []\\n\",\n    \"for i in range(num_steps):\\n\",\n    \"    record = np.load(record_path / f\\\"step_{i}.npy\\\", allow_pickle=True).item()\\n\",\n    \"    records.append(record)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(\\\"length of records\\\", len(records))\\n\",\n    \"print(\\\"keys in records\\\", records[0].keys())\\n\",\n    \"\\n\",\n    \"for k in records[0]:\\n\",\n    \"    print(f\\\"{k} shape: {records[0][k].shape}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from PIL import Image\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def get_image(step: int, idx: int = 0):\\n\",\n    \"    img = (255 * records[step][\\\"inputs/image\\\"]).astype(np.uint8)\\n\",\n    \"    return img[idx].transpose(1, 2, 0)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def show_image(step: int, idx_lst: list[int]):\\n\",\n    \"    imgs = [get_image(step, idx) for idx in idx_lst]\\n\",\n    \"    return Image.fromarray(np.hstack(imgs))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"for i in range(2):\\n\",\n    \"    display(show_image(i, [0]))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import pandas as pd\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def get_axis(name, axis):\\n\",\n    \"    return np.array([record[name][axis] for record in records])\\n\",\n    
\"\\n\",\n    \"\\n\",\n    \"# qpos is [..., 14] of type float:\\n\",\n    \"# 0-5: left arm joint angles\\n\",\n    \"# 6: left arm gripper\\n\",\n    \"# 7-12: right arm joint angles\\n\",\n    \"# 13: right arm gripper\\n\",\n    \"names = [(\\\"left_joint\\\", 6), (\\\"left_gripper\\\", 1), (\\\"right_joint\\\", 6), (\\\"right_gripper\\\", 1)]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def make_data():\\n\",\n    \"    cur_dim = 0\\n\",\n    \"    in_data = {}\\n\",\n    \"    out_data = {}\\n\",\n    \"    for name, dim_size in names:\\n\",\n    \"        for i in range(dim_size):\\n\",\n    \"            in_data[f\\\"{name}_{i}\\\"] = get_axis(\\\"inputs/qpos\\\", cur_dim)\\n\",\n    \"            out_data[f\\\"{name}_{i}\\\"] = get_axis(\\\"outputs/qpos\\\", cur_dim)\\n\",\n    \"            cur_dim += 1\\n\",\n    \"    return pd.DataFrame(in_data), pd.DataFrame(out_data)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"in_data, out_data = make_data()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for name in in_data.columns:\\n\",\n    \"    data = pd.DataFrame({f\\\"in_{name}\\\": in_data[name], f\\\"out_{name}\\\": out_data[name]})\\n\",\n    \"    data.plot()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \".venv\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "examples/simple_client/Dockerfile",
    "content": "# Dockerfile for the simple client.\n\n# Build the container:\n# docker build . -t simple_client -f examples/simple_client/Dockerfile\n\n# Run the container:\n# docker run --rm -it --network=host -v .:/app simple_client /bin/bash\n\nFROM python:3.7-slim\nCOPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/\n\nWORKDIR /app\n\n# Copy from the cache instead of linking since it's a mounted volume\nENV UV_LINK_MODE=copy\n\n# Write the virtual environment outside of the project directory so it doesn't\n# leak out of the container when we mount the application code.\nENV UV_PROJECT_ENVIRONMENT=/.venv\n\n# Copy the requirements files so we can install dependencies.\n# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.\n# This strategy is best for development-style usage.\nCOPY ./examples/simple_client/requirements.txt /tmp/requirements.txt\nCOPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml\n\n# Install python dependencies.\nRUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT\nRUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml\nENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src\n\nCMD /bin/bash -c \"source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS\"\n"
  },
  {
    "path": "examples/simple_client/README.md",
    "content": "# Simple Client\n\nA minimal client that sends observations to the server and prints the inference rate.\n\nYou can specify which runtime environment to use with the `--env` flag. You can see the available options by running:\n\n```bash\nuv run examples/simple_client/main.py --help\n```\n\n## With Docker\n\n```bash\nexport SERVER_ARGS=\"--env ALOHA_SIM\"\ndocker compose -f examples/simple_client/compose.yml up --build\n```\n\n## Without Docker\n\nTerminal window 1:\n\n```bash\nuv run examples/simple_client/main.py --env DROID\n```\n\nTerminal window 2:\n\n```bash\nuv run scripts/serve_policy.py --env DROID\n```\n"
  },
  {
    "path": "examples/simple_client/compose.yml",
    "content": "# Run with:\n# docker compose -f examples/simple_client/compose.yml up --build\nservices:\n  runtime:\n    image: simple_client\n    depends_on:\n      - openpi_server\n    build:\n      context: ../..\n      dockerfile: examples/simple_client/Dockerfile\n    init: true\n    tty: true\n    network_mode: host\n    volumes:\n      - $PWD:/app\n    environment:\n      - SERVER_ARGS  \n\n  openpi_server:\n    image: openpi_server\n    build:\n      context: ../..\n      dockerfile: scripts/docker/serve_policy.Dockerfile\n    init: true\n    tty: true\n    network_mode: host\n    volumes:\n      - $PWD:/app\n      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets\n    environment:\n      - SERVER_ARGS\n      - OPENPI_DATA_HOME=/openpi_assets\n      - IS_DOCKER=true\n\n    # Comment out this block if not running on a machine with GPUs.\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: 1\n              capabilities: [gpu]\n"
  },
  {
    "path": "examples/simple_client/main.py",
    "content": "import dataclasses\nimport enum\nimport logging\nimport pathlib\nimport time\n\nimport numpy as np\nfrom openpi_client import websocket_client_policy as _websocket_client_policy\nimport polars as pl\nimport rich\nimport tqdm\nimport tyro\n\nlogger = logging.getLogger(__name__)\n\n\nclass EnvMode(enum.Enum):\n    \"\"\"Supported environments.\"\"\"\n\n    ALOHA = \"aloha\"\n    ALOHA_SIM = \"aloha_sim\"\n    DROID = \"droid\"\n    LIBERO = \"libero\"\n\n\n@dataclasses.dataclass\nclass Args:\n    \"\"\"Command line arguments.\"\"\"\n\n    # Host and port to connect to the server.\n    host: str = \"0.0.0.0\"\n    # Port to connect to the server. If None, the server will use the default port.\n    port: int | None = 8000\n    # API key to use for the server.\n    api_key: str | None = None\n    # Number of steps to run the policy for.\n    num_steps: int = 20\n    # Path to save the timings to a parquet file. (e.g., timing.parquet)\n    timing_file: pathlib.Path | None = None\n    # Environment to run the policy in.\n    env: EnvMode = EnvMode.ALOHA_SIM\n\n\nclass TimingRecorder:\n    \"\"\"Records timing measurements for different keys.\"\"\"\n\n    def __init__(self) -> None:\n        self._timings: dict[str, list[float]] = {}\n\n    def record(self, key: str, time_ms: float) -> None:\n        \"\"\"Record a timing measurement for the given key.\"\"\"\n        if key not in self._timings:\n            self._timings[key] = []\n        self._timings[key].append(time_ms)\n\n    def get_stats(self, key: str) -> dict[str, float]:\n        \"\"\"Get statistics for the given key.\"\"\"\n        times = self._timings[key]\n        return {\n            \"mean\": float(np.mean(times)),\n            \"std\": float(np.std(times)),\n            \"p25\": float(np.quantile(times, 0.25)),\n            \"p50\": float(np.quantile(times, 0.50)),\n            \"p75\": float(np.quantile(times, 0.75)),\n            \"p90\": float(np.quantile(times, 0.90)),\n            
\"p95\": float(np.quantile(times, 0.95)),\n            \"p99\": float(np.quantile(times, 0.99)),\n        }\n\n    def print_all_stats(self) -> None:\n        \"\"\"Print statistics for all keys in a concise format.\"\"\"\n\n        table = rich.table.Table(\n            title=\"[bold blue]Timing Statistics[/bold blue]\",\n            show_header=True,\n            header_style=\"bold white\",\n            border_style=\"blue\",\n            title_justify=\"center\",\n        )\n\n        # Add metric column with custom styling\n        table.add_column(\"Metric\", style=\"cyan\", justify=\"left\", no_wrap=True)\n\n        # Add statistical columns with consistent styling\n        stat_columns = [\n            (\"Mean\", \"yellow\", \"mean\"),\n            (\"Std\", \"yellow\", \"std\"),\n            (\"P25\", \"magenta\", \"p25\"),\n            (\"P50\", \"magenta\", \"p50\"),\n            (\"P75\", \"magenta\", \"p75\"),\n            (\"P90\", \"magenta\", \"p90\"),\n            (\"P95\", \"magenta\", \"p95\"),\n            (\"P99\", \"magenta\", \"p99\"),\n        ]\n\n        for name, style, _ in stat_columns:\n            table.add_column(name, justify=\"right\", style=style, no_wrap=True)\n\n        # Add rows for each metric with formatted values\n        for key in sorted(self._timings.keys()):\n            stats = self.get_stats(key)\n            values = [f\"{stats[key]:.1f}\" for _, _, key in stat_columns]\n            table.add_row(key, *values)\n\n        # Print with custom console settings\n        console = rich.console.Console(width=None, highlight=True)\n        console.print(table)\n\n    def write_parquet(self, path: pathlib.Path) -> None:\n        \"\"\"Save the timings to a parquet file.\"\"\"\n        logger.info(f\"Writing timings to {path}\")\n        frame = pl.DataFrame(self._timings)\n        path.parent.mkdir(parents=True, exist_ok=True)\n        frame.write_parquet(path)\n\n\ndef main(args: Args) -> None:\n    obs_fn = {\n        
EnvMode.ALOHA: _random_observation_aloha,\n        EnvMode.ALOHA_SIM: _random_observation_aloha,\n        EnvMode.DROID: _random_observation_droid,\n        EnvMode.LIBERO: _random_observation_libero,\n    }[args.env]\n\n    policy = _websocket_client_policy.WebsocketClientPolicy(\n        host=args.host,\n        port=args.port,\n        api_key=args.api_key,\n    )\n    logger.info(f\"Server metadata: {policy.get_server_metadata()}\")\n\n    # Send a few observations to make sure the model is loaded.\n    for _ in range(2):\n        policy.infer(obs_fn())\n\n    timing_recorder = TimingRecorder()\n\n    for _ in tqdm.trange(args.num_steps, desc=\"Running policy\"):\n        inference_start = time.time()\n        action = policy.infer(obs_fn())\n        timing_recorder.record(\"client_infer_ms\", 1000 * (time.time() - inference_start))\n        for key, value in action.get(\"server_timing\", {}).items():\n            timing_recorder.record(f\"server_{key}\", value)\n        for key, value in action.get(\"policy_timing\", {}).items():\n            timing_recorder.record(f\"policy_{key}\", value)\n\n    timing_recorder.print_all_stats()\n\n    if args.timing_file is not None:\n        timing_recorder.write_parquet(args.timing_file)\n\n\ndef _random_observation_aloha() -> dict:\n    return {\n        \"state\": np.ones((14,)),\n        \"images\": {\n            \"cam_high\": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),\n            \"cam_low\": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),\n            \"cam_left_wrist\": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),\n            \"cam_right_wrist\": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),\n        },\n        \"prompt\": \"do something\",\n    }\n\n\ndef _random_observation_droid() -> dict:\n    return {\n        \"observation/exterior_image_1_left\": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),\n        \"observation/wrist_image_left\": 
np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),\n        \"observation/joint_position\": np.random.rand(7),\n        \"observation/gripper_position\": np.random.rand(1),\n        \"prompt\": \"do something\",\n    }\n\n\ndef _random_observation_libero() -> dict:\n    return {\n        \"observation/state\": np.random.rand(8),\n        \"observation/image\": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),\n        \"observation/wrist_image\": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),\n        \"prompt\": \"do something\",\n    }\n\n\nif __name__ == \"__main__\":\n    logging.basicConfig(level=logging.INFO)\n    main(tyro.cli(Args))\n"
  },
  {
    "path": "examples/simple_client/requirements.in",
    "content": "numpy>=1.22.4,<2.0.0\nrich\ntqdm\ntyro\npolars"
  },
  {
    "path": "examples/simple_client/requirements.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.11.9\ndocstring-parser==0.16\n    # via tyro\nmarkdown-it-py==3.0.0\n    # via rich\nmdurl==0.1.2\n    # via markdown-it-py\nnumpy==1.26.4\n    # via -r examples/simple_client/requirements.in\npolars==1.30.0\n    # via -r examples/simple_client/requirements.in\npygments==2.19.1\n    # via rich\nrich==14.0.0\n    # via\n    #   -r examples/simple_client/requirements.in\n    #   tyro\nshtab==1.7.2\n    # via tyro\ntqdm==4.67.1\n    # via -r examples/simple_client/requirements.in\ntypeguard==4.4.2\n    # via tyro\ntyping-extensions==4.13.2\n    # via\n    #   typeguard\n    #   tyro\ntyro==0.9.22\n    # via -r examples/simple_client/requirements.in\n"
  },
  {
    "path": "examples/ur5/README.md",
    "content": "# UR5 Example\n\nBelow we provide an outline of how to implement the key components mentioned in the \"Finetune on your data\" section of the [README](../README.md) for finetuning on UR5 datasets.\n\nFirst, we will define the `UR5Inputs` and `UR5Outputs` classes, which map the UR5 environment to the model and vice versa. Check the corresponding files in `src/openpi/policies/libero_policy.py` for comments explaining each line.\n\n```python\n\n@dataclasses.dataclass(frozen=True)\nclass UR5Inputs(transforms.DataTransformFn):\n\n    model_type: _model.ModelType = _model.ModelType.PI0\n\n    def __call__(self, data: dict) -> dict:\n        # First, concatenate the joints and gripper into the state vector.\n        state = np.concatenate([data[\"joints\"], data[\"gripper\"]])\n\n        # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically\n        # stores as float32 (C,H,W), gets skipped for policy inference.\n        base_image = _parse_image(data[\"base_rgb\"])\n        wrist_image = _parse_image(data[\"wrist_rgb\"])\n\n        # Create inputs dict.\n        inputs = {\n            \"state\": state,\n            \"image\": {\n                \"base_0_rgb\": base_image,\n                \"left_wrist_0_rgb\": wrist_image,\n                # Since there is no right wrist, replace with zeros\n                \"right_wrist_0_rgb\": np.zeros_like(base_image),\n            },\n            \"image_mask\": {\n                \"base_0_rgb\": np.True_,\n                \"left_wrist_0_rgb\": np.True_,\n                # Since the \"slot\" for the right wrist is not used, this mask is set\n                # to False\n                \"right_wrist_0_rgb\": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,\n            },\n        }\n\n        if \"actions\" in data:\n            inputs[\"actions\"] = data[\"actions\"]\n\n        # Pass the prompt (aka language instruction) to the model.\n        if \"prompt\" in data:\n 
           inputs[\"prompt\"] = data[\"prompt\"]\n\n        return inputs\n\n\n@dataclasses.dataclass(frozen=True)\nclass UR5Outputs(transforms.DataTransformFn):\n\n    def __call__(self, data: dict) -> dict:\n        # Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims\n        return {\"actions\": np.asarray(data[\"actions\"][:, :7])}\n\n```\n\nNext, we will define the `UR5DataConfig` class, which defines how to process raw UR5 data from LeRobot dataset for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).\n\n```python\n\n@dataclasses.dataclass(frozen=True)\nclass LeRobotUR5DataConfig(DataConfigFactory):\n\n    @override\n    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:\n        # Boilerplate for remapping keys from the LeRobot dataset. We assume no renaming needed here.\n        repack_transform = _transforms.Group(\n            inputs=[\n                _transforms.RepackTransform(\n                    {\n                        \"base_rgb\": \"image\",\n                        \"wrist_rgb\": \"wrist_image\",\n                        \"joints\": \"joints\",\n                        \"gripper\": \"gripper\",\n                        \"prompt\": \"prompt\",\n                    }\n                )\n            ]\n        )\n\n        # These transforms are the ones we wrote earlier.\n        data_transforms = _transforms.Group(\n            inputs=[UR5Inputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],\n            outputs=[UR5Outputs()],\n        )\n\n        # Convert absolute actions to delta actions.\n        # By convention, we do not convert the gripper action (7th dimension).\n        delta_action_mask = _transforms.make_bool_mask(6, -1)\n        data_transforms = data_transforms.push(\n            
inputs=[_transforms.DeltaActions(delta_action_mask)],\n            outputs=[_transforms.AbsoluteActions(delta_action_mask)],\n        )\n\n        # Model transforms include things like tokenizing the prompt and action targets\n        # You do not need to change anything here for your own dataset.\n        model_transforms = ModelTransformFactory()(model_config)\n\n        # We return all data transforms for training and inference. No need to change anything here.\n        return dataclasses.replace(\n            self.create_base_config(assets_dirs),\n            repack_transforms=repack_transform,\n            data_transforms=data_transforms,\n            model_transforms=model_transforms,\n        )\n\n```\n\nFinally, we define the TrainConfig for our UR5 dataset. Here, we define a config for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.\n\n```python\nTrainConfig(\n    name=\"pi0_ur5\",\n    model=pi0.Pi0Config(),\n    data=LeRobotUR5DataConfig(\n        repo_id=\"your_username/ur5_dataset\",\n        # This config lets us reload the UR5 normalization stats from the base model checkpoint.\n        # Reloading normalization stats can help transfer pre-trained models to new environments.\n        # See the [norm_stats.md](../docs/norm_stats.md) file for more details.\n        assets=AssetsConfig(\n            assets_dir=\"gs://openpi-assets/checkpoints/pi0_base/assets\",\n            asset_id=\"ur5e\",\n        ),\n        base_config=DataConfig(\n            # This flag determines whether we load the prompt (i.e. the task instruction) from the\n            # ``task`` field in the LeRobot dataset. 
The recommended setting is True.\n            prompt_from_task=True,\n        ),\n    ),\n    # Load the pi0 base model checkpoint.\n    weight_loader=weight_loaders.CheckpointWeightLoader(\"gs://openpi-assets/checkpoints/pi0_base/params\"),\n    num_train_steps=30_000,\n)\n```\n\n\n\n\n\n"
  },
  {
    "path": "packages/openpi-client/pyproject.toml",
    "content": "[project]\nname = \"openpi-client\"\nversion = \"0.1.0\"\nrequires-python = \">=3.7\"\ndependencies = [\n    \"dm-tree>=0.1.8\",\n    \"msgpack>=1.0.5\",\n    \"numpy>=1.22.4,<2.0.0\",\n    \"pillow>=9.0.0\",\n    \"tree>=0.2.4\",\n    \"websockets>=11.0\",\n]\n\n[build-system]\nrequires = [\"hatchling\"]\nbuild-backend = \"hatchling.build\"\n\n[tool.uv]\ndev-dependencies = [\"pytest>=8.3.4\"]\n\n[tool.ruff]\nline-length = 120\ntarget-version = \"py37\"\n"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/__init__.py",
    "content": "__version__ = \"0.1.0\"\n"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/action_chunk_broker.py",
    "content": "from typing import Dict\n\nimport numpy as np\nimport tree\nfrom typing_extensions import override\n\nfrom openpi_client import base_policy as _base_policy\n\n\nclass ActionChunkBroker(_base_policy.BasePolicy):\n    \"\"\"Wraps a policy to return action chunks one-at-a-time.\n\n    Assumes that the first dimension of all action fields is the chunk size.\n\n    A new inference call to the inner policy is only made when the current\n    list of chunks is exhausted.\n    \"\"\"\n\n    def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):\n        self._policy = policy\n        self._action_horizon = action_horizon\n        self._cur_step: int = 0\n\n        self._last_results: Dict[str, np.ndarray] | None = None\n\n    @override\n    def infer(self, obs: Dict) -> Dict:  # noqa: UP006\n        if self._last_results is None:\n            self._last_results = self._policy.infer(obs)\n            self._cur_step = 0\n\n        def slicer(x):\n            if isinstance(x, np.ndarray):\n                return x[self._cur_step, ...]\n            else:\n                return x\n\n        results = tree.map_structure(slicer, self._last_results)\n        self._cur_step += 1\n\n        if self._cur_step >= self._action_horizon:\n            self._last_results = None\n\n        return results\n\n    @override\n    def reset(self) -> None:\n        self._policy.reset()\n        self._last_results = None\n        self._cur_step = 0\n"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/base_policy.py",
    "content": "import abc\nfrom typing import Dict\n\n\nclass BasePolicy(abc.ABC):\n    @abc.abstractmethod\n    def infer(self, obs: Dict) -> Dict:\n        \"\"\"Infer actions from observations.\"\"\"\n\n    def reset(self) -> None:\n        \"\"\"Reset the policy to its initial state.\"\"\"\n        pass\n"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/image_tools.py",
    "content": "import numpy as np\nfrom PIL import Image\n\n\ndef convert_to_uint8(img: np.ndarray) -> np.ndarray:\n    \"\"\"Converts an image to uint8 if it is a float image.\n\n    This is important for reducing the size of the image when sending it over the network.\n    \"\"\"\n    if np.issubdtype(img.dtype, np.floating):\n        img = (255 * img).astype(np.uint8)\n    return img\n\n\ndef resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:\n    \"\"\"Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.\n\n    Args:\n        images: A batch of images in [..., height, width, channel] format.\n        height: The target height of the image.\n        width: The target width of the image.\n        method: The interpolation method to use. Default is bilinear.\n\n    Returns:\n        The resized images in [..., height, width, channel].\n    \"\"\"\n    # If the images are already the correct size, return them as is.\n    if images.shape[-3:-1] == (height, width):\n        return images\n\n    original_shape = images.shape\n\n    images = images.reshape(-1, *original_shape[-3:])\n    resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images])\n    return resized.reshape(*original_shape[:-3], *resized.shape[-3:])\n\n\ndef _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:\n    \"\"\"Replicates tf.image.resize_with_pad for one image using PIL. 
Resizes an image to a target height and\n    width without distortion by padding with zeros.\n\n    Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].\n    \"\"\"\n    cur_width, cur_height = image.size\n    if cur_width == width and cur_height == height:\n        return image  # No need to resize if the image is already the correct size.\n\n    ratio = max(cur_width / width, cur_height / height)\n    resized_height = int(cur_height / ratio)\n    resized_width = int(cur_width / ratio)\n    resized_image = image.resize((resized_width, resized_height), resample=method)\n\n    zero_image = Image.new(resized_image.mode, (width, height), 0)\n    pad_height = max(0, int((height - resized_height) / 2))\n    pad_width = max(0, int((width - resized_width) / 2))\n    zero_image.paste(resized_image, (pad_width, pad_height))\n    assert zero_image.size == (width, height)\n    return zero_image\n"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/image_tools_test.py",
    "content": "import numpy as np\n\nimport openpi_client.image_tools as image_tools\n\n\ndef test_resize_with_pad_shapes():\n    # Test case 1: Resize image with larger dimensions\n    images = np.zeros((2, 10, 10, 3), dtype=np.uint8)  # Input images of shape (batch_size, height, width, channels)\n    height = 20\n    width = 20\n    resized_images = image_tools.resize_with_pad(images, height, width)\n    assert resized_images.shape == (2, height, width, 3)\n    assert np.all(resized_images == 0)\n\n    # Test case 2: Resize image with smaller dimensions\n    images = np.zeros((3, 30, 30, 3), dtype=np.uint8)\n    height = 15\n    width = 15\n    resized_images = image_tools.resize_with_pad(images, height, width)\n    assert resized_images.shape == (3, height, width, 3)\n    assert np.all(resized_images == 0)\n\n    # Test case 3: Resize image with the same dimensions\n    images = np.zeros((1, 50, 50, 3), dtype=np.uint8)\n    height = 50\n    width = 50\n    resized_images = image_tools.resize_with_pad(images, height, width)\n    assert resized_images.shape == (1, height, width, 3)\n    assert np.all(resized_images == 0)\n\n    # Test case 4: Resize image with odd-numbered padding\n    images = np.zeros((1, 256, 320, 3), dtype=np.uint8)\n    height = 60\n    width = 80\n    resized_images = image_tools.resize_with_pad(images, height, width)\n    assert resized_images.shape == (1, height, width, 3)\n    assert np.all(resized_images == 0)\n"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/msgpack_numpy.py",
    "content": "\"\"\"Adds NumPy array support to msgpack.\n\nmsgpack is good for (de)serializing data over a network for multiple reasons:\n- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)\n- msgpack is widely used and has good cross-language support\n- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed\n    languages like Python and JavaScript\n- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster\n    than pickle for serializing large arrays using the below strategy\n\nThe code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is\nthat it falls back to pickle for object arrays.\n\"\"\"\n\nimport functools\n\nimport msgpack\nimport numpy as np\n\n\ndef pack_array(obj):\n    if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in (\"V\", \"O\", \"c\"):\n        raise ValueError(f\"Unsupported dtype: {obj.dtype}\")\n\n    if isinstance(obj, np.ndarray):\n        return {\n            b\"__ndarray__\": True,\n            b\"data\": obj.tobytes(),\n            b\"dtype\": obj.dtype.str,\n            b\"shape\": obj.shape,\n        }\n\n    if isinstance(obj, np.generic):\n        return {\n            b\"__npgeneric__\": True,\n            b\"data\": obj.item(),\n            b\"dtype\": obj.dtype.str,\n        }\n\n    return obj\n\n\ndef unpack_array(obj):\n    if b\"__ndarray__\" in obj:\n        return np.ndarray(buffer=obj[b\"data\"], dtype=np.dtype(obj[b\"dtype\"]), shape=obj[b\"shape\"])\n\n    if b\"__npgeneric__\" in obj:\n        return np.dtype(obj[b\"dtype\"]).type(obj[b\"data\"])\n\n    return obj\n\n\nPacker = functools.partial(msgpack.Packer, default=pack_array)\npackb = functools.partial(msgpack.packb, default=pack_array)\n\nUnpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)\nunpackb = 
functools.partial(msgpack.unpackb, object_hook=unpack_array)\n"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/msgpack_numpy_test.py",
    "content": "import numpy as np\nimport pytest\nimport tree\n\nfrom openpi_client import msgpack_numpy\n\n\ndef _check(expected, actual):\n    if isinstance(expected, np.ndarray):\n        assert expected.shape == actual.shape\n        assert expected.dtype == actual.dtype\n        assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == \"f\")\n    else:\n        assert expected == actual\n\n\n@pytest.mark.parametrize(\n    \"data\",\n    [\n        1,  # int\n        1.0,  # float\n        \"hello\",  # string\n        np.bool_(True),  # boolean scalar\n        np.array([1, 2, 3])[0],  # int scalar\n        np.str_(\"asdf\"),  # string scalar\n        [1, 2, 3],  # list\n        {\"key\": \"value\"},  # dict\n        {\"key\": [1, 2, 3]},  # nested dict\n        np.array(1.0),  # 0D array\n        np.array([1, 2, 3], dtype=np.int32),  # 1D integer array\n        np.array([\"asdf\", \"qwer\"]),  # string array\n        np.array([True, False]),  # boolean array\n        np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),  # 2D float array\n        np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16),  # 3D integer array\n        np.array([np.nan, np.inf, -np.inf]),  # special float values\n        {\"arr\": np.array([1, 2, 3]), \"nested\": {\"arr\": np.array([4, 5, 6])}},  # nested dict with arrays\n        [np.array([1, 2]), np.array([3, 4])],  # list of arrays\n        np.zeros((3, 4, 5), dtype=np.float32),  # 3D zeros\n        np.ones((2, 3), dtype=np.float64),  # 2D ones with double precision\n    ],\n)\ndef test_pack_unpack(data):\n    packed = msgpack_numpy.packb(data)\n    unpacked = msgpack_numpy.unpackb(packed)\n    tree.map_structure(_check, data, unpacked)\n"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/runtime/agent.py",
    "content": "import abc\n\n\nclass Agent(abc.ABC):\n    \"\"\"An Agent is the thing with agency, i.e. the entity that makes decisions.\n\n    Agents receive observations about the state of the world, and return actions\n    to take in response.\n    \"\"\"\n\n    @abc.abstractmethod\n    def get_action(self, observation: dict) -> dict:\n        \"\"\"Query the agent for the next action.\"\"\"\n\n    @abc.abstractmethod\n    def reset(self) -> None:\n        \"\"\"Reset the agent to its initial state.\"\"\"\n"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py",
    "content": "from typing_extensions import override\n\nfrom openpi_client import base_policy as _base_policy\nfrom openpi_client.runtime import agent as _agent\n\n\nclass PolicyAgent(_agent.Agent):\n    \"\"\"An agent that uses a policy to determine actions.\"\"\"\n\n    def __init__(self, policy: _base_policy.BasePolicy) -> None:\n        self._policy = policy\n\n    @override\n    def get_action(self, observation: dict) -> dict:\n        return self._policy.infer(observation)\n\n    def reset(self) -> None:\n        self._policy.reset()\n"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/runtime/environment.py",
    "content": "import abc\n\n\nclass Environment(abc.ABC):\n    \"\"\"An Environment represents the robot and the environment it inhabits.\n\n    The primary contract of environments is that they can be queried for observations\n    about their state, and have actions applied to them to change that state.\n    \"\"\"\n\n    @abc.abstractmethod\n    def reset(self) -> None:\n        \"\"\"Reset the environment to its initial state.\n\n        This will be called once before starting each episode.\n        \"\"\"\n\n    @abc.abstractmethod\n    def is_episode_complete(self) -> bool:\n        \"\"\"Allow the environment to signal that the episode is complete.\n\n        This will be called after each step. It should return `True` if the episode is\n        complete (either successfully or unsuccessfully), and `False` otherwise.\n        \"\"\"\n\n    @abc.abstractmethod\n    def get_observation(self) -> dict:\n        \"\"\"Query the environment for the current state.\"\"\"\n\n    @abc.abstractmethod\n    def apply_action(self, action: dict) -> None:\n        \"\"\"Take an action in the environment.\"\"\"\n"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/runtime/runtime.py",
    "content": "import logging\nimport threading\nimport time\n\nfrom openpi_client.runtime import agent as _agent\nfrom openpi_client.runtime import environment as _environment\nfrom openpi_client.runtime import subscriber as _subscriber\n\n\nclass Runtime:\n    \"\"\"The core module orchestrating interactions between key components of the system.\"\"\"\n\n    def __init__(\n        self,\n        environment: _environment.Environment,\n        agent: _agent.Agent,\n        subscribers: list[_subscriber.Subscriber],\n        max_hz: float = 0,\n        num_episodes: int = 1,\n        max_episode_steps: int = 0,\n    ) -> None:\n        self._environment = environment\n        self._agent = agent\n        self._subscribers = subscribers\n        self._max_hz = max_hz\n        self._num_episodes = num_episodes\n        self._max_episode_steps = max_episode_steps\n\n        self._in_episode = False\n        self._episode_steps = 0\n\n    def run(self) -> None:\n        \"\"\"Runs the runtime loop continuously until stop() is called or the environment is done.\"\"\"\n        for _ in range(self._num_episodes):\n            self._run_episode()\n\n        # Final reset, this is important for real environments to move the robot to its home position.\n        self._environment.reset()\n\n    def run_in_new_thread(self) -> threading.Thread:\n        \"\"\"Runs the runtime loop in a new thread.\"\"\"\n        thread = threading.Thread(target=self.run)\n        thread.start()\n        return thread\n\n    def mark_episode_complete(self) -> None:\n        \"\"\"Marks the end of an episode.\"\"\"\n        self._in_episode = False\n\n    def _run_episode(self) -> None:\n        \"\"\"Runs a single episode.\"\"\"\n        logging.info(\"Starting episode...\")\n        self._environment.reset()\n        self._agent.reset()\n        for subscriber in self._subscribers:\n            subscriber.on_episode_start()\n\n        self._in_episode = True\n        self._episode_steps = 0\n  
      step_time = 1 / self._max_hz if self._max_hz > 0 else 0\n        last_step_time = time.time()\n\n        while self._in_episode:\n            self._step()\n            self._episode_steps += 1\n\n            # Sleep to maintain the desired frame rate\n            now = time.time()\n            dt = now - last_step_time\n            if dt < step_time:\n                time.sleep(step_time - dt)\n                last_step_time = time.time()\n            else:\n                last_step_time = now\n\n        logging.info(\"Episode completed.\")\n        for subscriber in self._subscribers:\n            subscriber.on_episode_end()\n\n    def _step(self) -> None:\n        \"\"\"A single step of the runtime loop.\"\"\"\n        observation = self._environment.get_observation()\n        action = self._agent.get_action(observation)\n        self._environment.apply_action(action)\n\n        for subscriber in self._subscribers:\n            subscriber.on_step(observation, action)\n\n        if self._environment.is_episode_complete() or (\n            self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps\n        ):\n            self.mark_episode_complete()\n"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/runtime/subscriber.py",
    "content": "import abc\n\n\nclass Subscriber(abc.ABC):\n    \"\"\"Subscribes to events in the runtime.\n\n    Subscribers can be used to save data, visualize, etc.\n    \"\"\"\n\n    @abc.abstractmethod\n    def on_episode_start(self) -> None:\n        \"\"\"Called when an episode starts.\"\"\"\n\n    @abc.abstractmethod\n    def on_step(self, observation: dict, action: dict) -> None:\n        \"\"\"Called after each step with the observation and the action that was applied.\"\"\"\n\n    @abc.abstractmethod\n    def on_episode_end(self) -> None:\n        \"\"\"Called when an episode ends.\"\"\"\n"
  },
  {
    "path": "packages/openpi-client/src/openpi_client/websocket_client_policy.py",
    "content": "import logging\nimport time\nfrom typing import Dict, Optional, Tuple\n\nfrom typing_extensions import override\nimport websockets.sync.client\n\nfrom openpi_client import base_policy as _base_policy\nfrom openpi_client import msgpack_numpy\n\n\nclass WebsocketClientPolicy(_base_policy.BasePolicy):\n    \"\"\"Implements the Policy interface by communicating with a server over websocket.\n\n    See WebsocketPolicyServer for a corresponding server implementation.\n    \"\"\"\n\n    def __init__(self, host: str = \"0.0.0.0\", port: Optional[int] = None, api_key: Optional[str] = None) -> None:\n        if host.startswith(\"ws\"):\n            self._uri = host\n        else:\n            self._uri = f\"ws://{host}\"\n        if port is not None:\n            self._uri += f\":{port}\"\n        self._packer = msgpack_numpy.Packer()\n        self._api_key = api_key\n        self._ws, self._server_metadata = self._wait_for_server()\n\n    def get_server_metadata(self) -> Dict:\n        return self._server_metadata\n\n    def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:\n        logging.info(f\"Waiting for server at {self._uri}...\")\n        while True:\n            try:\n                headers = {\"Authorization\": f\"Api-Key {self._api_key}\"} if self._api_key else None\n                conn = websockets.sync.client.connect(\n                    self._uri, compression=None, max_size=None, additional_headers=headers\n                )\n                metadata = msgpack_numpy.unpackb(conn.recv())\n                return conn, metadata\n            except ConnectionRefusedError:\n                logging.info(\"Still waiting for server...\")\n                time.sleep(5)\n\n    @override\n    def infer(self, obs: Dict) -> Dict:  # noqa: UP006\n        data = self._packer.pack(obs)\n        self._ws.send(data)\n        response = self._ws.recv()\n        if isinstance(response, str):\n            # we're expecting bytes; if 
the server sends a string, it's an error.\n            raise RuntimeError(f\"Error in inference server:\\n{response}\")\n        return msgpack_numpy.unpackb(response)\n\n    @override\n    def reset(self) -> None:\n        pass\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"openpi\"\nversion = \"0.1.0\"\ndescription = \"Physical Intelligence open source repo\"\nreadme = \"README.md\"\nrequires-python = \">=3.11\"\nlicense = { file = \"LICENSE\" }\ndependencies = [\n    \"augmax>=0.3.4\",\n    \"dm-tree>=0.1.8\",\n    \"einops>=0.8.0\",\n    \"equinox>=0.11.8\",\n    \"flatbuffers>=24.3.25\",\n    \"flax==0.10.2\",\n    \"fsspec[gcs]>=2024.6.0\",\n    \"gym-aloha>=0.1.1\",\n    \"imageio>=2.36.1\",\n    \"jax[cuda12]==0.5.3\",\n    \"jaxtyping==0.2.36\",\n    \"lerobot\",\n    \"ml_collections==1.0.0\",\n    \"numpy>=1.22.4,<2.0.0\",\n    \"numpydantic>=1.6.6\",\n    \"opencv-python>=4.10.0.84\",\n    \"openpi-client\",\n    \"orbax-checkpoint==0.11.13\",\n    \"pillow>=11.0.0\",\n    \"sentencepiece>=0.2.0\",\n    \"torch==2.7.1\",\n    \"tqdm-loggable>=0.2\",\n    \"typing-extensions>=4.12.2\",\n    \"tyro>=0.9.5\",\n    \"wandb>=0.19.1\",\n    \"filelock>=3.16.1\",\n    \"beartype==0.19.0\",\n    \"treescope>=0.1.7\",\n    \"transformers==4.53.2\",\n    \"rich>=14.0.0\",\n    \"polars>=1.30.0\",\n]\n\n\n[project.urls]\nRepository = \"https://github.com/Physical-Intelligence/openpi\"\n\n[dependency-groups]\ndev = [\n    \"pytest>=8.3.4\",\n    \"ruff>=0.8.6\",\n    \"pre-commit>=4.0.1\",\n    \"ipykernel>=6.29.5\",\n    \"ipywidgets>=8.1.5\",\n    \"matplotlib>=3.10.0\",\n    \"pynvml>=12.0.0\",\n]\nrlds = [\n    \"dlimp\",\n    \"tensorflow-cpu==2.15.0\",\n    \"tensorflow-datasets==4.9.9\",\n]\n\n[tool.uv]\noverride-dependencies = [\"ml-dtypes==0.4.1\", \"tensorstore==0.1.74\"]\n\n[tool.uv.sources]\nopenpi-client = { workspace = true }\nlerobot = { git = \"https://github.com/huggingface/lerobot\", rev = \"0cf864870cf29f4738d3ade893e6fd13fbd7cdb5\" }\ndlimp = { git = \"https://github.com/kvablack/dlimp\", rev = \"ad72ce3a9b414db2185bc0b38461d4101a65477a\" }\n\n[tool.uv.workspace]\nmembers = [\"packages/*\"]\n\n[tool.ruff]\nline-length = 120\ntarget-version = \"py311\"\nextend-exclude = 
[\"docker\", \"third_party\", \"src/openpi/models_pytorch/transformers_replace/*\"]\n\n[tool.ruff.lint]\n# https://docs.astral.sh/ruff/rules/\nselect = [\n    \"B\",\n    \"C4\",\n    \"DTZ\",\n    \"E4\",\n    \"E7\",\n    \"E9\",\n    \"F\",\n    \"FBT\",\n    \"FURB\",\n    \"I\",\n    \"ICN\",\n    \"ISC\",\n    \"LOG\",\n    \"N\",\n    \"PD\",\n    \"PERF\",\n    \"PIE\",\n    \"PLC\",\n    \"PLE\",\n    \"PLR1\",\n    \"PLR5\",\n    \"PLW\",\n    \"PT\",\n    \"Q\",\n    \"RET\",\n    \"RUF\",\n    \"SIM\",\n    \"SLF\",\n    \"T10\",\n    \"T20\",\n    \"UP\",\n    \"W\",\n]\nignore = [\n    \"F722\",   # Conflicts with array typing.\n    \"T201\",   # We use print statements.\n    \"PD008\",  # Lots of false positives.\n    \"ISC001\", # Disabling to support ruff format.\n    \"LOG015\", # Use logger.info.\n]\nunfixable = [\n    \"B905\", # Fix defaults to strict=False, which is not what we want.\n]\n\n[tool.ruff.lint.isort]\nforce-single-line = true\nforce-sort-within-sections = true\nsingle-line-exclusions = [\"collections.abc\", \"typing\", \"typing_extensions\"]\nknown-third-party = [\"wandb\"]\n\n[build-system]\nrequires = [\"hatchling\"]\nbuild-backend = \"hatchling.build\"\n\n[tool.pytest.ini_options]\nmarkers = [\"manual: should be run manually.\"]\ntestpaths = [\"src\", \"scripts\", \"packages\"]\n"
  },
  {
    "path": "scripts/__init__.py",
    "content": ""
  },
  {
    "path": "scripts/compute_norm_stats.py",
    "content": "\"\"\"Compute normalization statistics for a config.\n\nThis script is used to compute the normalization statistics for a given config. It\nwill compute the mean and standard deviation of the data in the dataset and save it\nto the config assets directory.\n\"\"\"\n\nimport numpy as np\nimport tqdm\nimport tyro\n\nimport openpi.models.model as _model\nimport openpi.shared.normalize as normalize\nimport openpi.training.config as _config\nimport openpi.training.data_loader as _data_loader\nimport openpi.transforms as transforms\n\n\nclass RemoveStrings(transforms.DataTransformFn):\n    def __call__(self, x: dict) -> dict:\n        return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}\n\n\ndef create_torch_dataloader(\n    data_config: _config.DataConfig,\n    action_horizon: int,\n    batch_size: int,\n    model_config: _model.BaseModelConfig,\n    num_workers: int,\n    max_frames: int | None = None,\n) -> tuple[_data_loader.Dataset, int]:\n    if data_config.repo_id is None:\n        raise ValueError(\"Data config must have a repo_id\")\n    dataset = _data_loader.create_torch_dataset(data_config, action_horizon, model_config)\n    dataset = _data_loader.TransformedDataset(\n        dataset,\n        [\n            *data_config.repack_transforms.inputs,\n            *data_config.data_transforms.inputs,\n            # Remove strings since they are not supported by JAX and are not needed to compute norm stats.\n            RemoveStrings(),\n        ],\n    )\n    if max_frames is not None and max_frames < len(dataset):\n        num_batches = max_frames // batch_size\n        shuffle = True\n    else:\n        num_batches = len(dataset) // batch_size\n        shuffle = False\n    data_loader = _data_loader.TorchDataLoader(\n        dataset,\n        local_batch_size=batch_size,\n        num_workers=num_workers,\n        shuffle=shuffle,\n        num_batches=num_batches,\n    )\n    return data_loader, 
num_batches\n\n\ndef create_rlds_dataloader(\n    data_config: _config.DataConfig,\n    action_horizon: int,\n    batch_size: int,\n    max_frames: int | None = None,\n) -> tuple[_data_loader.Dataset, int]:\n    dataset = _data_loader.create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=False)\n    dataset = _data_loader.IterableTransformedDataset(\n        dataset,\n        [\n            *data_config.repack_transforms.inputs,\n            *data_config.data_transforms.inputs,\n            # Remove strings since they are not supported by JAX and are not needed to compute norm stats.\n            RemoveStrings(),\n        ],\n        is_batched=True,\n    )\n    if max_frames is not None and max_frames < len(dataset):\n        num_batches = max_frames // batch_size\n    else:\n        # NOTE: this length is currently hard-coded for DROID.\n        num_batches = len(dataset) // batch_size\n    data_loader = _data_loader.RLDSDataLoader(\n        dataset,\n        num_batches=num_batches,\n    )\n    return data_loader, num_batches\n\n\ndef main(config_name: str, max_frames: int | None = None):\n    config = _config.get_config(config_name)\n    data_config = config.data.create(config.assets_dirs, config.model)\n\n    if data_config.rlds_data_dir is not None:\n        data_loader, num_batches = create_rlds_dataloader(\n            data_config, config.model.action_horizon, config.batch_size, max_frames\n        )\n    else:\n        data_loader, num_batches = create_torch_dataloader(\n            data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames\n        )\n\n    keys = [\"state\", \"actions\"]\n    stats = {key: normalize.RunningStats() for key in keys}\n\n    for batch in tqdm.tqdm(data_loader, total=num_batches, desc=\"Computing stats\"):\n        for key in keys:\n            stats[key].update(np.asarray(batch[key]))\n\n    norm_stats = {key: stats.get_statistics() for key, stats in 
stats.items()}\n\n    output_path = config.assets_dirs / data_config.repo_id\n    print(f\"Writing stats to: {output_path}\")\n    normalize.save(output_path, norm_stats)\n\n\nif __name__ == \"__main__\":\n    tyro.cli(main)\n"
  },
  {
    "path": "scripts/docker/compose.yml",
    "content": "# Run with:\n# docker compose -f scripts/docker/compose.yml up --build\nservices:\n  openpi_server:\n    image: openpi_server\n    build:\n      context: ../..\n      dockerfile: scripts/docker/serve_policy.Dockerfile\n    init: true\n    tty: true\n    network_mode: host\n    # Mount the current working directory at /app inside the container.\n    # Mount the configured openpi data home (default: ~/.cache/openpi) at /openpi_assets inside the container.\n    volumes:\n      - $PWD:/app\n      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets\n    environment:\n      - SERVER_ARGS\n      - OPENPI_DATA_HOME=/openpi_assets\n      - IS_DOCKER=true\n\n    # Comment out this block if not running on a machine with GPUs.\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: 1\n              capabilities: [gpu]\n"
  },
  {
    "path": "scripts/docker/install_docker_ubuntu22.sh",
    "content": "#!/bin/bash\n\n# Add Docker's official GPG key:\nsudo apt-get update\nsudo apt-get install -y ca-certificates curl\nsudo install -m 0755 -d /etc/apt/keyrings\nsudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc\nsudo chmod a+r /etc/apt/keyrings/docker.asc\n\n# Add the repository to Apt sources:\necho \\\n\t\"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \\\n  $(. /etc/os-release && echo \"$VERSION_CODENAME\") stable\" |\n\tsudo tee /etc/apt/sources.list.d/docker.list >/dev/null\nsudo apt-get update\n\nsudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin\n\n# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).\n# See https://docs.docker.com/engine/install/linux-postinstall/\nusername=$(whoami)\nsudo usermod -aG docker $username\n\n# Configure docker to start automatically on system boot.\nsudo systemctl enable docker.service\nsudo systemctl enable containerd.service\n\n# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5\nif [ ~/.docker/config.json ]; then\n\tsed -i 's/credsStore/credStore/g' ~/.docker/config.json\nfi\n\necho \"\"\necho \"********************************************************************\"\necho \"**** Restart to allow Docker permission changes to take effect. ****\"\necho \"********************************************************************\"\necho \"\"\n"
  },
  {
    "path": "scripts/docker/install_nvidia_container_toolkit.sh",
    "content": "#!/bin/bash\n\n# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.\n# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html\n\ncurl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&\n\tcurl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |\n\tsed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |\n\t\tsudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list\n\n# NVIDIA's documentation omits 'sudo' in the following command, but it is required.\nsudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list\nsudo apt-get update\nsudo apt-get install -y nvidia-container-toolkit\n\nsudo nvidia-ctk runtime configure --runtime=docker\nsudo systemctl restart docker\n"
  },
  {
    "path": "scripts/docker/serve_policy.Dockerfile",
    "content": "# Dockerfile for serving a PI policy.\n# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container\n\n# Build the container:\n# docker build . -t openpi_server -f scripts/docker/serve_policy.Dockerfile\n\n# Run the container:\n# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash\n\nFROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0\nCOPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/\n\nWORKDIR /app\n\n# Needed because LeRobot uses git-lfs.\nRUN apt-get update && apt-get install -y git git-lfs linux-headers-generic build-essential clang\n\n# Copy from the cache instead of linking since it's a mounted volume\nENV UV_LINK_MODE=copy\n\n# Write the virtual environment outside of the project directory so it doesn't\n# leak out of the container when we mount the application code.\nENV UV_PROJECT_ENVIRONMENT=/.venv\n\n# Install the project's dependencies using the lockfile and settings\nRUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT\nRUN --mount=type=cache,target=/root/.cache/uv \\\n    --mount=type=bind,source=uv.lock,target=uv.lock \\\n    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \\\n    --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \\\n    --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \\\n    GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev\n\n# Copy transformers_replace files while preserving directory structure\nCOPY src/openpi/models_pytorch/transformers_replace/ /tmp/transformers_replace/\nRUN /.venv/bin/python -c \"import transformers; print(transformers.__file__)\" | xargs dirname | xargs -I{} cp -r /tmp/transformers_replace/* {} && rm -rf /tmp/transformers_replace\n\nCMD /bin/bash -c \"uv run scripts/serve_policy.py $SERVER_ARGS\"\n"
  },
  {
    "path": "scripts/serve_policy.py",
    "content": "import dataclasses\nimport enum\nimport logging\nimport socket\n\nimport tyro\n\nfrom openpi.policies import policy as _policy\nfrom openpi.policies import policy_config as _policy_config\nfrom openpi.serving import websocket_policy_server\nfrom openpi.training import config as _config\n\n\nclass EnvMode(enum.Enum):\n    \"\"\"Supported environments.\"\"\"\n\n    ALOHA = \"aloha\"\n    ALOHA_SIM = \"aloha_sim\"\n    DROID = \"droid\"\n    LIBERO = \"libero\"\n\n\n@dataclasses.dataclass\nclass Checkpoint:\n    \"\"\"Load a policy from a trained checkpoint.\"\"\"\n\n    # Training config name (e.g., \"pi0_aloha_sim\").\n    config: str\n    # Checkpoint directory (e.g., \"checkpoints/pi0_aloha_sim/exp/10000\").\n    dir: str\n\n\n@dataclasses.dataclass\nclass Default:\n    \"\"\"Use the default policy for the given environment.\"\"\"\n\n\n@dataclasses.dataclass\nclass Args:\n    \"\"\"Arguments for the serve_policy script.\"\"\"\n\n    # Environment to serve the policy for. This is only used when serving default policies.\n    env: EnvMode = EnvMode.ALOHA_SIM\n\n    # If provided, will be used in case the \"prompt\" key is not present in the data, or if the model doesn't have a default\n    # prompt.\n    default_prompt: str | None = None\n\n    # Port to serve the policy on.\n    port: int = 8000\n    # Record the policy's behavior for debugging.\n    record: bool = False\n\n    # Specifies how to load the policy. 
If not provided, the default policy for the environment will be used.\n    policy: Checkpoint | Default = dataclasses.field(default_factory=Default)\n\n\n# Default checkpoints that should be used for each environment.\nDEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {\n    EnvMode.ALOHA: Checkpoint(\n        config=\"pi05_aloha\",\n        dir=\"gs://openpi-assets/checkpoints/pi05_base\",\n    ),\n    EnvMode.ALOHA_SIM: Checkpoint(\n        config=\"pi0_aloha_sim\",\n        dir=\"gs://openpi-assets/checkpoints/pi0_aloha_sim\",\n    ),\n    EnvMode.DROID: Checkpoint(\n        config=\"pi05_droid\",\n        dir=\"gs://openpi-assets/checkpoints/pi05_droid\",\n    ),\n    EnvMode.LIBERO: Checkpoint(\n        config=\"pi05_libero\",\n        dir=\"gs://openpi-assets/checkpoints/pi05_libero\",\n    ),\n}\n\n\ndef create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:\n    \"\"\"Create a default policy for the given environment.\"\"\"\n    if checkpoint := DEFAULT_CHECKPOINT.get(env):\n        return _policy_config.create_trained_policy(\n            _config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt\n        )\n    raise ValueError(f\"Unsupported environment mode: {env}\")\n\n\ndef create_policy(args: Args) -> _policy.Policy:\n    \"\"\"Create a policy from the given arguments.\"\"\"\n    match args.policy:\n        case Checkpoint():\n            return _policy_config.create_trained_policy(\n                _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt\n            )\n        case Default():\n            return create_default_policy(args.env, default_prompt=args.default_prompt)\n\n\ndef main(args: Args) -> None:\n    policy = create_policy(args)\n    policy_metadata = policy.metadata\n\n    # Record the policy's behavior.\n    if args.record:\n        policy = _policy.PolicyRecorder(policy, \"policy_records\")\n\n    hostname = socket.gethostname()\n   
 local_ip = socket.gethostbyname(hostname)\n    logging.info(\"Creating server (host: %s, ip: %s)\", hostname, local_ip)\n\n    server = websocket_policy_server.WebsocketPolicyServer(\n        policy=policy,\n        host=\"0.0.0.0\",\n        port=args.port,\n        metadata=policy_metadata,\n    )\n    server.serve_forever()\n\n\nif __name__ == \"__main__\":\n    logging.basicConfig(level=logging.INFO, force=True)\n    main(tyro.cli(Args))\n"
  },
  {
    "path": "scripts/train.py",
    "content": "import dataclasses\nimport functools\nimport logging\nimport platform\nfrom typing import Any\n\nimport etils.epath as epath\nimport flax.nnx as nnx\nfrom flax.training import common_utils\nimport flax.traverse_util as traverse_util\nimport jax\nimport jax.experimental\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\nimport tqdm_loggable.auto as tqdm\nimport wandb\n\nimport openpi.models.model as _model\nimport openpi.shared.array_typing as at\nimport openpi.shared.nnx_utils as nnx_utils\nimport openpi.training.checkpoints as _checkpoints\nimport openpi.training.config as _config\nimport openpi.training.data_loader as _data_loader\nimport openpi.training.optimizer as _optimizer\nimport openpi.training.sharding as sharding\nimport openpi.training.utils as training_utils\nimport openpi.training.weight_loaders as _weight_loaders\n\n\ndef init_logging():\n    \"\"\"Custom logging format for better readability.\"\"\"\n    level_mapping = {\"DEBUG\": \"D\", \"INFO\": \"I\", \"WARNING\": \"W\", \"ERROR\": \"E\", \"CRITICAL\": \"C\"}\n\n    class CustomFormatter(logging.Formatter):\n        def format(self, record):\n            record.levelname = level_mapping.get(record.levelname, record.levelname)\n            return super().format(record)\n\n    formatter = CustomFormatter(\n        fmt=\"%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)\",\n        datefmt=\"%H:%M:%S\",\n    )\n\n    logger = logging.getLogger()\n    logger.setLevel(logging.INFO)\n    logger.handlers[0].setFormatter(formatter)\n\n\ndef init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):\n    if not enabled:\n        wandb.init(mode=\"disabled\")\n        return\n\n    ckpt_dir = config.checkpoint_dir\n    if not ckpt_dir.exists():\n        raise FileNotFoundError(f\"Checkpoint directory {ckpt_dir} does not exist.\")\n    if resuming:\n        run_id = (ckpt_dir / 
\"wandb_id.txt\").read_text().strip()\n        wandb.init(id=run_id, resume=\"must\", project=config.project_name)\n    else:\n        wandb.init(\n            name=config.exp_name,\n            config=dataclasses.asdict(config),\n            project=config.project_name,\n        )\n        (ckpt_dir / \"wandb_id.txt\").write_text(wandb.run.id)\n\n    if log_code:\n        wandb.run.log_code(epath.Path(__file__).parent.parent)\n\n\ndef _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:\n    \"\"\"Loads and validates the weights. Returns a loaded subset of the weights.\"\"\"\n    loaded_params = loader.load(params_shape)\n    at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)\n\n    # Remove jax.ShapeDtypeStruct from the loaded params. This makes sure that only the loaded params are returned.\n    return traverse_util.unflatten_dict(\n        {k: v for k, v in traverse_util.flatten_dict(loaded_params).items() if not isinstance(v, jax.ShapeDtypeStruct)}\n    )\n\n\n@at.typecheck\ndef init_train_state(\n    config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool\n) -> tuple[training_utils.TrainState, Any]:\n    tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)\n\n    def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:\n        rng, model_rng = jax.random.split(rng)\n        # initialize the model (and its parameters).\n        model = config.model.create(model_rng)\n\n        # Merge the partial params into the model.\n        if partial_params is not None:\n            graphdef, state = nnx.split(model)\n            # This will produce an error if the partial params are not a subset of the state.\n            state.replace_by_pure_dict(partial_params)\n            model = nnx.merge(graphdef, state)\n\n        params = 
nnx.state(model)\n        # Convert frozen params to bfloat16.\n        params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))\n\n        return training_utils.TrainState(\n            step=0,\n            params=params,\n            model_def=nnx.graphdef(model),\n            tx=tx,\n            opt_state=tx.init(params.filter(config.trainable_filter)),\n            ema_decay=config.ema_decay,\n            ema_params=None if config.ema_decay is None else params,\n        )\n\n    train_state_shape = jax.eval_shape(init, init_rng)\n    state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)\n\n    if resume:\n        return train_state_shape, state_sharding\n\n    partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())\n    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())\n\n    # Initialize the train state and mix in the partial params.\n    train_state = jax.jit(\n        init,\n        donate_argnums=(1,),  # donate the partial params buffer.\n        in_shardings=replicated_sharding,\n        out_shardings=state_sharding,\n    )(init_rng, partial_params)\n\n    return train_state, state_sharding\n\n\n@at.typecheck\ndef train_step(\n    config: _config.TrainConfig,\n    rng: at.KeyArrayLike,\n    state: training_utils.TrainState,\n    batch: tuple[_model.Observation, _model.Actions],\n) -> tuple[training_utils.TrainState, dict[str, at.Array]]:\n    model = nnx.merge(state.model_def, state.params)\n    model.train()\n\n    @at.typecheck\n    def loss_fn(\n        model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions\n    ):\n        chunked_loss = model.compute_loss(rng, observation, actions, train=True)\n        return jnp.mean(chunked_loss)\n\n    train_rng = jax.random.fold_in(rng, state.step)\n    observation, actions = batch\n\n    # Filter out frozen 
params.\n    diff_state = nnx.DiffState(0, config.trainable_filter)\n    loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)\n\n    params = state.params.filter(config.trainable_filter)\n    updates, new_opt_state = state.tx.update(grads, state.opt_state, params)\n    new_params = optax.apply_updates(params, updates)\n\n    # Update the model in place and return the new full state.\n    nnx.update(model, new_params)\n    new_params = nnx.state(model)\n\n    new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)\n    if state.ema_decay is not None:\n        new_state = dataclasses.replace(\n            new_state,\n            ema_params=jax.tree.map(\n                lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params\n            ),\n        )\n\n    # Filter out params that aren't kernels.\n    kernel_params = nnx.state(\n        model,\n        nnx.All(\n            nnx.Param,\n            nnx.Not(nnx_utils.PathRegex(\".*/(bias|scale|pos_embedding|input_embedding)\")),\n            lambda _, x: x.value.ndim > 1,\n        ),\n    )\n    info = {\n        \"loss\": loss,\n        \"grad_norm\": optax.global_norm(grads),\n        \"param_norm\": optax.global_norm(kernel_params),\n    }\n    return new_state, info\n\n\ndef main(config: _config.TrainConfig):\n    init_logging()\n    logging.info(f\"Running on: {platform.node()}\")\n\n    if config.batch_size % jax.device_count() != 0:\n        raise ValueError(\n            f\"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}.\"\n        )\n\n    jax.config.update(\"jax_compilation_cache_dir\", str(epath.Path(\"~/.cache/jax\").expanduser()))\n\n    rng = jax.random.key(config.seed)\n    train_rng, init_rng = jax.random.split(rng)\n\n    mesh = sharding.make_mesh(config.fsdp_devices)\n    data_sharding = 
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))\n    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())\n\n    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(\n        config.checkpoint_dir,\n        keep_period=config.keep_period,\n        overwrite=config.overwrite,\n        resume=config.resume,\n    )\n    init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)\n\n    data_loader = _data_loader.create_data_loader(\n        config,\n        sharding=data_sharding,\n        shuffle=True,\n    )\n    data_iter = iter(data_loader)\n    batch = next(data_iter)\n    logging.info(f\"Initialized data loader:\\n{training_utils.array_tree_to_info(batch)}\")\n\n    # Log images from first batch to sanity check.\n    images_to_log = [\n        wandb.Image(np.concatenate([np.array(img[i]) for img in batch[0].images.values()], axis=1))\n        for i in range(min(5, len(next(iter(batch[0].images.values())))))\n    ]\n    wandb.log({\"camera_views\": images_to_log}, step=0)\n\n    train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)\n    jax.block_until_ready(train_state)\n    logging.info(f\"Initialized train state:\\n{training_utils.array_tree_to_info(train_state.params)}\")\n\n    if resuming:\n        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)\n\n    ptrain_step = jax.jit(\n        functools.partial(train_step, config),\n        in_shardings=(replicated_sharding, train_state_sharding, data_sharding),\n        out_shardings=(train_state_sharding, replicated_sharding),\n        donate_argnums=(1,),\n    )\n\n    start_step = int(train_state.step)\n    pbar = tqdm.tqdm(\n        range(start_step, config.num_train_steps),\n        initial=start_step,\n        total=config.num_train_steps,\n        dynamic_ncols=True,\n    )\n\n    infos = []\n    for step in pbar:\n        with 
sharding.set_mesh(mesh):\n            train_state, info = ptrain_step(train_rng, train_state, batch)\n        infos.append(info)\n        if step % config.log_interval == 0:\n            stacked_infos = common_utils.stack_forest(infos)\n            reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))\n            info_str = \", \".join(f\"{k}={v:.4f}\" for k, v in reduced_info.items())\n            pbar.write(f\"Step {step}: {info_str}\")\n            wandb.log(reduced_info, step=step)\n            infos = []\n        batch = next(data_iter)\n\n        if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:\n            _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)\n\n    logging.info(\"Waiting for checkpoint manager to finish\")\n    checkpoint_manager.wait_until_finished()\n\n\nif __name__ == \"__main__\":\n    main(_config.cli())\n"
  },
  {
    "path": "scripts/train_pytorch.py",
    "content": "\"\"\"\nPyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.\nThis script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs\nentirely in PyTorch using the `PI0Pytorch` model and your existing config/data\npipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.\n\nUsage\nSingle GPU:\n  python scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>\n  Example:\n  python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test\n  python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume  # Resume from latest checkpoint\nMulti-GPU (single node):\n  torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>\n  Example:\n  torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test\n  torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume\nMulti-Node Training:\n  torchrun \\\n    --nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<rank_of_node> \\\n    --master_addr=<master_ip> --master_port=<port> \\\n    scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>\n\n\"\"\"\n\nimport dataclasses\nimport gc\nimport logging\nimport os\nimport platform\nimport shutil\nimport time\n\nimport jax\nimport numpy as np\nimport safetensors.torch\nimport torch\nimport torch.distributed as dist\nimport torch.nn.parallel\nimport tqdm\nimport wandb\n\nimport openpi.models.pi0_config\nimport openpi.models_pytorch.pi0_pytorch\nimport openpi.shared.normalize as _normalize\nimport openpi.training.config as _config\nimport openpi.training.data_loader as _data\n\n\ndef init_logging():\n    level_mapping = {\"DEBUG\": \"D\", \"INFO\": \"I\", \"WARNING\": \"W\", \"ERROR\": \"E\", \"CRITICAL\": \"C\"}\n\n    class CustomFormatter(logging.Formatter):\n        def format(self, record):\n            record.levelname = level_mapping.get(record.levelname, record.levelname)\n            return super().format(record)\n\n    formatter = CustomFormatter(\n        fmt=\"%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)\",\n        datefmt=\"%H:%M:%S\",\n    )\n    logger = logging.getLogger()\n    logger.setLevel(logging.INFO)\n    if not logger.handlers:\n        ch = logging.StreamHandler()\n        ch.setFormatter(formatter)\n        logger.addHandler(ch)\n    else:\n        logger.handlers[0].setFormatter(formatter)\n\n\ndef init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):\n    \"\"\"Initialize wandb logging.\"\"\"\n    if not enabled:\n        wandb.init(mode=\"disabled\")\n        return\n\n    ckpt_dir = config.checkpoint_dir\n    if not ckpt_dir.exists():\n        raise FileNotFoundError(f\"Checkpoint directory {ckpt_dir} does not exist.\")\n\n    if resuming:\n        run_id = (ckpt_dir / \"wandb_id.txt\").read_text().strip()\n        wandb.init(id=run_id, resume=\"must\", project=config.project_name)\n    else:\n        wandb.init(\n            name=config.exp_name,\n            config=dataclasses.asdict(config),\n            project=config.project_name,\n        )\n        (ckpt_dir / \"wandb_id.txt\").write_text(wandb.run.id)\n\n\ndef setup_ddp():\n    world_size = int(os.environ.get(\"WORLD_SIZE\", \"1\"))\n    use_ddp = world_size > 1\n    if use_ddp and not torch.distributed.is_initialized():\n        backend = \"nccl\" if torch.cuda.is_available() else \"gloo\"\n        torch.distributed.init_process_group(backend=backend, init_method=\"env://\")\n\n        # Set up debugging environment variables for DDP issues\n        if os.environ.get(\"TORCH_DISTRIBUTED_DEBUG\") is None:\n            os.environ[\"TORCH_DISTRIBUTED_DEBUG\"] = \"INFO\"\n\n    local_rank = int(os.environ.get(\"LOCAL_RANK\", os.environ.get(\"RANK\", \"0\")))\n    device = torch.device(f\"cuda:{local_rank}\" if torch.cuda.is_available() else \"cpu\")\n    if torch.cuda.is_available():\n        torch.cuda.set_device(device)\n    return use_ddp, local_rank, device\n\n\ndef cleanup_ddp():\n    if torch.distributed.is_initialized():\n        torch.distributed.barrier()\n        torch.distributed.destroy_process_group()\n\n\ndef set_seed(seed: int, local_rank: int):\n    torch.manual_seed(seed + local_rank)\n    np.random.seed(seed + local_rank)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(seed + local_rank)\n\n\ndef build_datasets(config: _config.TrainConfig):\n    # Use the unified data loader with PyTorch framework\n    data_loader = _data.create_data_loader(config, framework=\"pytorch\", shuffle=True)\n    return data_loader, data_loader.data_config()\n\n\ndef get_model_state_dict(model):\n    \"\"\"Get state dict from model, handling DDP wrapper.\"\"\"\n    return (\n        model.module.state_dict()\n        if isinstance(model, torch.nn.parallel.DistributedDataParallel)\n        else model.state_dict()\n    )\n\n\ndef get_model_parameters(model):\n    \"\"\"Get parameters from model, handling DDP wrapper.\"\"\"\n    return (\n        model.module.parameters()\n        if isinstance(model, torch.nn.parallel.DistributedDataParallel)\n        else model.parameters()\n    )\n\n\ndef save_checkpoint(model, optimizer, global_step, config, is_main, data_config):\n    \"\"\"Save a checkpoint with model state, optimizer state, and metadata.\n\n    `global_step` is the number of optimizer steps completed so far: the training loop increments it\n    before calling this function, so the last checkpoint of a run is written when it equals\n    `config.num_train_steps`. Only the main process writes; other ranks return immediately.\n    \"\"\"\n    if not is_main:\n        return\n\n    # Only save if it's time to save or if the final training step has just completed\n    if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps:\n        # Create temporary directory for atomic checkpoint saving\n        final_ckpt_dir = config.checkpoint_dir / f\"{global_step}\"\n        tmp_ckpt_dir = config.checkpoint_dir / f\"tmp_{global_step}\"\n\n        # Remove any existing temp directory and create new one\n        if tmp_ckpt_dir.exists():\n            shutil.rmtree(tmp_ckpt_dir)\n        tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)\n\n        # Save model state using safetensors (handle shared tensors)\n        model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model\n        safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / \"model.safetensors\")\n\n        # Save optimizer state using PyTorch format\n        torch.save(optimizer.state_dict(), tmp_ckpt_dir / \"optimizer.pt\")\n\n        # Save training metadata (the config is stored as a plain dict via dataclasses.asdict)\n        metadata = {\n            \"global_step\": global_step,\n            \"config\": dataclasses.asdict(config),\n            \"timestamp\": time.time(),\n        }\n        torch.save(metadata, tmp_ckpt_dir / \"metadata.pt\")\n\n        # save norm stats\n        norm_stats = data_config.norm_stats\n        if norm_stats is not None and data_config.asset_id is not None:\n            _normalize.save(tmp_ckpt_dir / \"assets\" / data_config.asset_id, norm_stats)\n\n        # Atomically move temp directory to final location\n        if final_ckpt_dir.exists():\n            shutil.rmtree(final_ckpt_dir)\n        tmp_ckpt_dir.rename(final_ckpt_dir)\n\n        logging.info(f\"Saved checkpoint at step {global_step} -> {final_ckpt_dir}\")\n\n        # Log checkpoint to wandb\n        if config.wandb_enabled:\n            wandb.log({\"checkpoint_step\": global_step}, step=global_step)\n\n\ndef load_checkpoint(model, optimizer, checkpoint_dir, device):\n    \"\"\"Load the latest checkpoint and return the global step.\"\"\"\n    checkpoint_steps = [\n        int(d.name)\n        for d in checkpoint_dir.iterdir()\n        if d.is_dir() and d.name.isdigit() and not d.name.startswith(\"tmp_\")\n    ]\n\n    if not checkpoint_steps:\n        raise FileNotFoundError(f\"No checkpoints found in {checkpoint_dir}\")\n\n    latest_step = max(checkpoint_steps)\n    ckpt_dir = checkpoint_dir / f\"{latest_step}\"\n\n    # Clear memory before loading checkpoints\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n        gc.collect()\n        log_memory_usage(device, latest_step, \"before_loading_checkpoint\")\n\n    try:\n        # Load model state with error handling\n        logging.info(\"Loading model state...\")\n        safetensors_path = ckpt_dir / \"model.safetensors\"\n\n        if safetensors_path.exists():\n            model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model\n            safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))\n            logging.info(\"Loaded model state from safetensors format\")\n        else:\n            raise FileNotFoundError(f\"No model checkpoint found at {ckpt_dir}\")\n\n        torch.cuda.empty_cache()\n        gc.collect()\n        log_memory_usage(device, latest_step, \"after_loading_model\")\n\n        # Load optimizer state with error handling\n        logging.info(\"Loading optimizer state...\")\n        optimizer_path = ckpt_dir / \"optimizer.pt\"\n\n        if optimizer_path.exists():\n            optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)\n            logging.info(\"Loaded optimizer state from pt format\")\n        else:\n            raise FileNotFoundError(f\"No optimizer checkpoint found at {ckpt_dir}\")\n\n        optimizer.load_state_dict(optimizer_state_dict)\n        del optimizer_state_dict\n        torch.cuda.empty_cache()\n        gc.collect()\n        log_memory_usage(device, latest_step, \"after_loading_optimizer\")\n\n        # Load metadata\n        logging.info(\"Loading metadata...\")\n        metadata = torch.load(ckpt_dir / \"metadata.pt\", map_location=device, weights_only=False)\n        global_step = metadata.get(\"global_step\", latest_step)\n        del metadata\n        torch.cuda.empty_cache()\n        gc.collect()\n        log_memory_usage(device, latest_step, \"after_loading_metadata\")\n\n        logging.info(f\"Successfully loaded all checkpoint components from step {latest_step}\")\n        return global_step\n\n    except RuntimeError as e:\n        if \"out of memory\" in str(e):\n            # Clear memory and provide detailed error message\n            torch.cuda.empty_cache()\n            gc.collect()\n            logging.error(f\"Out of memory error while loading checkpoint: {e!s}\")\n            log_memory_usage(device, latest_step, \"after_oom_error\")\n            raise RuntimeError(\n                \"Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True\"\n            ) from e\n        raise\n\n\ndef get_latest_checkpoint_step(checkpoint_dir):\n    \"\"\"Get the latest checkpoint step number from a checkpoint directory.\"\"\"\n    checkpoint_steps = [\n        int(d.name)\n        for d in checkpoint_dir.iterdir()\n        if d.is_dir() and d.name.isdigit() and not d.name.startswith(\"tmp_\")\n    ]\n    return max(checkpoint_steps) if checkpoint_steps else None\n\n\ndef log_memory_usage(device, step, phase=\"unknown\"):\n    \"\"\"Log detailed memory usage information.\"\"\"\n    if not torch.cuda.is_available():\n        return\n\n    memory_allocated = torch.cuda.memory_allocated(device) / 1e9\n    memory_reserved = torch.cuda.memory_reserved(device) / 1e9\n    # The free figure is memory reserved by the caching allocator but not currently allocated to tensors;\n    # it is not the free memory remaining on the device.\n    memory_free = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)\n    memory_free = memory_free / 1e9\n\n    # Get more detailed memory info\n    memory_stats = torch.cuda.memory_stats(device)\n    max_memory_allocated = memory_stats.get(\"allocated_bytes.all.peak\", 0) / 1e9\n    max_memory_reserved = memory_stats.get(\"reserved_bytes.all.peak\", 0) / 1e9\n\n    # Get DDP info if available\n    ddp_info = \"\"\n    if dist.is_initialized():\n        ddp_info = f\" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}\"\n\n    logging.info(\n        f\"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}\"\n    )\n\n\ndef train_loop(config: _config.TrainConfig):\n    \"\"\"Train the PyTorch PI0/PI05 model described by `config`, optionally under DDP.\"\"\"\n    # The CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF only once, when it is first initialized, so the\n    # large-scale allocator settings must be in place before setup_ddp() selects the CUDA device and the model is\n    # moved onto it. setup_ddp() derives the DDP world size from the same WORLD_SIZE variable. A value that is\n    # already set in the environment takes precedence.\n    if int(os.environ.get(\"WORLD_SIZE\", \"1\")) >= 8:\n        os.environ.setdefault(\"PYTORCH_CUDA_ALLOC_CONF\", \"max_split_size_mb:128,expandable_segments:True\")\n\n    use_ddp, local_rank, device = setup_ddp()\n    is_main = (not use_ddp) or (dist.get_rank() == 0)\n    set_seed(config.seed, local_rank)\n\n    # Initialize checkpoint directory and wandb\n    resuming = False\n    if config.resume:\n        # Find checkpoint directory based on experiment name\n        exp_checkpoint_dir = config.checkpoint_dir\n        if exp_checkpoint_dir.exists():\n            # Use validation to find the latest working checkpoint\n            latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)\n            if latest_step is not None:\n                resuming = True\n                logging.info(\n                    f\"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}\"\n                )\n            else:\n                raise FileNotFoundError(f\"No valid checkpoints found in {exp_checkpoint_dir} for resume\")\n        else:\n            raise FileNotFoundError(f\"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume\")\n    elif config.overwrite and config.checkpoint_dir.exists():\n        shutil.rmtree(config.checkpoint_dir)\n        logging.info(f\"Overwriting checkpoint directory: {config.checkpoint_dir}\")\n\n    # Create checkpoint directory with experiment name\n    if not resuming:\n        # For new runs, create experiment-specific checkpoint directory\n        exp_checkpoint_dir = config.checkpoint_dir\n        exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)\n        logging.info(f\"Created experiment checkpoint directory: {exp_checkpoint_dir}\")\n    else:\n        # For resume, checkpoint_dir is already set to the experiment directory\n        logging.info(f\"Using existing experiment checkpoint directory: {config.checkpoint_dir}\")\n\n    # Initialize wandb (only on main process)\n    if is_main:\n        init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)\n\n    # Build data loader using the unified data loader\n    # Calculate effective batch size per GPU for DDP\n    # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size\n    world_size = torch.distributed.get_world_size() if use_ddp else 1\n    effective_batch_size = config.batch_size // world_size\n    logging.info(\n        f\"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})\"\n    )\n\n    # Pass the original batch size to data loader - it will handle DDP splitting internally\n    loader, data_config = build_datasets(config)\n\n    # Log sample images to wandb on first batch\n    if is_main and config.wandb_enabled and not resuming:\n        # Create a separate data loader for sample batch to avoid consuming the main loader\n        sample_data_loader = _data.create_data_loader(config, framework=\"pytorch\", shuffle=False)\n        sample_batch = next(iter(sample_data_loader))\n        # Convert observation and actions to torch tensors\n        observation, actions = sample_batch\n        sample_batch = observation.to_dict()\n        sample_batch[\"actions\"] = actions\n\n        # Create sample images for wandb\n        images_to_log = []\n        # Get batch size from the first image tensor\n        batch_size = next(iter(sample_batch[\"image\"].values())).shape[0]\n        for i in range(min(5, batch_size)):\n            # Concatenate all camera views horizontally for this batch item\n            # Convert from NCHW to NHWC format for wandb\n            img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch[\"image\"].values()], axis=1)\n            img_concatenated = img_concatenated.cpu().numpy()\n            images_to_log.append(wandb.Image(img_concatenated))\n\n        wandb.log({\"camera_views\": images_to_log}, step=0)\n\n        # Clear sample batch from memory aggressively\n        del sample_batch, observation, actions, images_to_log, img_concatenated\n        del sample_data_loader  # Also delete the sample data loader\n        gc.collect()\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n        logging.info(\"Cleared sample batch and data loader from memory\")\n\n    # Build model\n    if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):\n        # Convert dataclass to Pi0Config if needed\n        model_cfg = openpi.models.pi0_config.Pi0Config(\n            dtype=config.pytorch_training_precision,\n            action_dim=config.model.action_dim,\n            action_horizon=config.model.action_horizon,\n            max_token_len=config.model.max_token_len,\n            paligemma_variant=getattr(config.model, \"paligemma_variant\", \"gemma_2b\"),\n            action_expert_variant=getattr(config.model, \"action_expert_variant\", \"gemma_300m\"),\n            pi05=getattr(config.model, \"pi05\", False),\n        )\n    else:\n        model_cfg = config.model\n        # Update dtype to match pytorch_training_precision\n        object.__setattr__(model_cfg, \"dtype\", config.pytorch_training_precision)\n\n    model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)\n\n    if hasattr(model, \"gradient_checkpointing_enable\"):\n        enable_gradient_checkpointing = True\n        model.gradient_checkpointing_enable()\n        logging.info(\"Enabled gradient checkpointing for memory optimization\")\n    else:\n        enable_gradient_checkpointing = False\n        logging.info(\"Gradient checkpointing is not supported for this model\")\n\n    # Log initial memory usage after model creation\n    if is_main and torch.cuda.is_available():\n        log_memory_usage(device, 0, \"after_model_creation\")\n\n    # Enable memory optimizations for large-scale training\n    # (PYTORCH_CUDA_ALLOC_CONF is configured at the top of train_loop, before CUDA is initialized.)\n    if world_size >= 8:\n        torch.backends.cudnn.benchmark = True\n        torch.backends.cuda.matmul.allow_tf32 = True\n        torch.backends.cudnn.allow_tf32 = True\n        logging.info(\"Enabled memory optimizations for 8+ GPU training\")\n\n    if use_ddp:\n        model = torch.nn.parallel.DistributedDataParallel(\n            model,\n            device_ids=[device.index] if device.type == \"cuda\" else None,\n            find_unused_parameters=True,  # Tolerate parameters that receive no gradient (adds per-step overhead)\n            gradient_as_bucket_view=True,  # Enable for memory efficiency\n            static_graph=world_size >= 8,  # Enable for 8+ GPUs\n        )\n\n    # Load weights from weight_loader if specified (for fine-tuning)\n    if config.pytorch_weight_path is not None:\n        logging.info(f\"Loading weights from: {config.pytorch_weight_path}\")\n\n        model_path = os.path.join(config.pytorch_weight_path, \"model.safetensors\")\n        safetensors.torch.load_model(\n            (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path\n        )\n        logging.info(f\"Loaded PyTorch weights from {config.pytorch_weight_path}\")\n\n    # Optimizer + learning rate schedule from config\n    warmup_steps = config.lr_schedule.warmup_steps\n    peak_lr = config.lr_schedule.peak_lr\n    decay_steps = config.lr_schedule.decay_steps\n    end_lr = config.lr_schedule.decay_lr\n\n    # Create optimizer with config parameters\n    optim = torch.optim.AdamW(\n        model.parameters(),\n        lr=peak_lr,\n        betas=(config.optimizer.b1, config.optimizer.b2),\n        eps=config.optimizer.eps,\n        weight_decay=config.optimizer.weight_decay,\n    )\n\n    # Load checkpoint if resuming\n    global_step = 0\n    if resuming:\n        global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)\n        logging.info(f\"Resumed training from step {global_step}\")\n\n    def lr_schedule(step: int):\n        if step < warmup_steps:\n            # Match JAX behavior: start from peak_lr / (warmup_steps + 1)\n            init_lr = peak_lr / (warmup_steps + 1)\n            return init_lr + (peak_lr - init_lr) * step / warmup_steps\n        # cosine decay\n        progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))\n        cos = 0.5 * (1 + np.cos(np.pi * progress))\n        return end_lr + (peak_lr - end_lr) * cos\n\n    model.train()\n    start_time = time.time()\n    infos = []  # Collect stats over log interval\n    if is_main:\n        logging.info(\n            f\"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}\"\n        )\n        logging.info(\n            f\"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}\"\n        )\n        logging.info(f\"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}\")\n        logging.info(\n            f\"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}\"\n        )\n        logging.info(\n            f\"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}\"\n        )\n        logging.info(\"EMA is not supported for PyTorch training\")\n        logging.info(f\"Training precision: {model_cfg.dtype}\")\n\n    # Training loop - iterate until we reach num_train_steps\n    pbar = (\n        tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc=\"Training\", disable=not is_main)\n        if is_main\n        else None\n    )\n\n    while global_step < config.num_train_steps:\n        # Set epoch for distributed training\n        if use_ddp and hasattr(loader, \"set_epoch\"):\n            loader.set_epoch(global_step // len(loader))\n\n        for observation, actions in loader:\n            # Check if we've reached the target number of steps\n            if global_step >= config.num_train_steps:\n                break\n\n            # The unified data loader returns (observation, actions) tuple\n            observation = jax.tree.map(lambda x: x.to(device), observation)  # noqa: PLW2901\n            actions = actions.to(torch.float32)  # noqa: PLW2901\n            actions = actions.to(device)  # noqa: PLW2901\n\n            # Update LR\n            for pg in optim.param_groups:\n                pg[\"lr\"] = lr_schedule(global_step)\n\n            # Forward pass\n            losses = model(observation, actions)\n            # Ensure losses is a tensor and handle different return types\n            if isinstance(losses, list | tuple):\n                losses = torch.stack(losses)\n            elif not isinstance(losses, torch.Tensor):\n                losses = torch.tensor(losses, device=device, dtype=torch.float32)\n\n            loss = losses.mean()\n\n            # Backward pass\n            loss.backward()\n\n            # Log memory usage after backward pass\n            if global_step < 5 and is_main and torch.cuda.is_available():\n                log_memory_usage(device, global_step, \"after_backward\")\n\n            # Gradient clipping\n            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)\n\n            # Optimizer step\n            optim.step()\n            optim.zero_grad(set_to_none=True)\n\n            # Clear gradients more aggressively\n            for param in model.parameters():\n                if param.grad is not None:\n                    param.grad.detach_()\n                    param.grad = None\n\n            # Collect stats\n            if is_main:\n                infos.append(\n                    {\n                        \"loss\": loss.item(),\n                        \"learning_rate\": optim.param_groups[0][\"lr\"],\n                        \"grad_norm\": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,\n                    }\n                )\n\n            if is_main and (global_step % config.log_interval == 0):\n                elapsed = time.time() - start_time\n\n                # Average stats over log interval\n                avg_loss = sum(info[\"loss\"] for info in infos) / len(infos)\n                avg_lr = sum(info[\"learning_rate\"] for info in infos) / len(infos)\n\n                avg_grad_norm = None\n                if any(\"grad_norm\" in info for info in infos):\n                    vals = [\n                        info[\"grad_norm\"] for info in infos if \"grad_norm\" in info and info[\"grad_norm\"] is not None\n                    ]\n                    if len(vals) > 0:\n                        avg_grad_norm = sum(vals) / len(vals)\n                logging.info(\n                    f\"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s\"\n                    if avg_grad_norm is not None\n                    else f\"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s\"\n                )\n\n                # Log to wandb\n                if config.wandb_enabled and len(infos) > 0:\n                    log_payload = {\n                        \"loss\": avg_loss,\n                        \"learning_rate\": avg_lr,\n                        \"step\": global_step,\n                        # Divide by the number of steps actually collected: the first window after a fresh\n                        # start or a resume can be shorter than log_interval.\n                        \"time_per_step\": elapsed / len(infos),\n                    }\n                    if avg_grad_norm is not None:\n                        log_payload[\"grad_norm\"] = avg_grad_norm\n                    wandb.log(log_payload, step=global_step)\n\n                start_time = time.time()\n                infos = []  # Reset stats collection\n\n            global_step += 1\n            # global_step now counts completed steps; save a checkpoint if one is due\n            save_checkpoint(model, optim, global_step, config, is_main, data_config)\n\n            # Update progress bar\n            if pbar is not None:\n                pbar.update(1)\n                pbar.set_postfix(\n                    {\"loss\": f\"{loss.item():.4f}\", \"lr\": f\"{optim.param_groups[0]['lr']:.2e}\", \"step\": global_step}\n                )\n\n    # Close progress bar\n    if pbar is not None:\n        pbar.close()\n\n    # Finish wandb run\n    if is_main and config.wandb_enabled:\n        wandb.finish()\n\n    cleanup_ddp()\n\n\ndef main():\n    init_logging()\n    config = _config.cli()\n    train_loop(config)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "scripts/train_test.py",
    "content": "import dataclasses\nimport os\nimport pathlib\n\nimport pytest\n\nos.environ[\"JAX_PLATFORMS\"] = \"cpu\"\n\nfrom openpi.training import config as _config\n\nfrom . import train\n\n\n@pytest.mark.parametrize(\"config_name\", [\"debug\"])\ndef test_train(tmp_path: pathlib.Path, config_name: str):\n    \"\"\"Smoke-test the JAX trainer: a short fresh run, then a resumed run that continues it.\"\"\"\n    base = _config._CONFIGS_DICT[config_name]  # noqa: SLF001\n    initial_run = dataclasses.replace(\n        base,\n        exp_name=\"test\",\n        checkpoint_base_dir=str(tmp_path / \"checkpoint\"),\n        batch_size=2,\n        num_train_steps=2,\n        log_interval=1,\n        resume=False,\n        overwrite=False,\n    )\n    train.main(initial_run)\n\n    # Resume from the checkpoint written by the first run and continue training up to step 4.\n    resumed_run = dataclasses.replace(initial_run, resume=True, num_train_steps=4)\n    train.main(resumed_run)\n"
  },
  {
    "path": "src/openpi/__init__.py",
    "content": ""
  },
  {
    "path": "src/openpi/conftest.py",
    "content": "import os\n\nimport pynvml\nimport pytest\n\n\ndef set_jax_cpu_backend_if_no_gpu() -> None:\n    try:\n        pynvml.nvmlInit()\n        pynvml.nvmlShutdown()\n    except pynvml.NVMLError:\n        # No GPU found.\n        os.environ[\"JAX_PLATFORMS\"] = \"cpu\"\n\n\ndef pytest_configure(config: pytest.Config) -> None:\n    set_jax_cpu_backend_if_no_gpu()\n"
  },
  {
    "path": "src/openpi/models/__init__.py",
    "content": ""
  },
  {
    "path": "src/openpi/models/gemma.py",
    "content": "# Copyright 2024 Big Vision Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Gemma adaptation for Pi, taken from big_vision.\n\nWe follow this einsum axis naming convention:\n  B: batch\n  T: query length\n  S: k/v length\n  N: num query heads\n  K: num k/v heads\n  G: num query heads per k/v head\n  H: head dim\n  D: d_model (\"features\")\n\"\"\"\n\nfrom collections.abc import Sequence\nimport dataclasses\nfrom typing import Literal, TypeAlias\n\nimport einops\nimport flax.linen as nn\nimport jax\nimport jax.numpy as jnp\n\nimport openpi.models.lora as lora\nimport openpi.shared.array_typing as at\nimport openpi.training.sharding as sharding\n\nPALIGEMMA_VOCAB_SIZE = 257_152\n\n\n@dataclasses.dataclass\nclass Config:\n    width: int\n    depth: int\n    mlp_dim: int\n    num_heads: int\n    num_kv_heads: int\n    head_dim: int\n    lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict)\n\n\nVariant = Literal[\"dummy\", \"gemma_300m\", \"gemma_300m_lora\", \"gemma_2b\", \"gemma_2b_lora\"]\n\n\ndef get_config(variant: Variant) -> Config:\n    \"\"\"Returns config for specified gemma variant.\"\"\"\n    if variant == \"dummy\":\n        return Config(\n            width=64,\n            depth=4,\n            mlp_dim=128,\n            num_heads=8,\n            num_kv_heads=1,\n            head_dim=16,\n        )\n    if variant == \"gemma_300m\":\n        # 311M params\n        return Config(\n          
  width=1024,\n            depth=18,\n            mlp_dim=4096,\n            num_heads=8,\n            num_kv_heads=1,\n            head_dim=256,\n        )\n    if variant == \"gemma_2b\":\n        return Config(\n            width=2048,\n            depth=18,\n            mlp_dim=16_384,\n            num_heads=8,\n            num_kv_heads=1,\n            head_dim=256,\n        )\n    if variant == \"gemma_2b_lora\":\n        return Config(\n            width=2048,\n            depth=18,\n            mlp_dim=16_384,\n            num_heads=8,\n            num_kv_heads=1,\n            head_dim=256,\n            lora_configs={\"attn\": lora.LoRAConfig(rank=16, alpha=16.0), \"ffn\": lora.LoRAConfig(rank=16, alpha=16.0)},\n        )\n    if variant == \"gemma_300m_lora\":\n        # 311M params\n        return Config(\n            width=1024,\n            depth=18,\n            mlp_dim=4096,\n            num_heads=8,\n            num_kv_heads=1,\n            head_dim=256,\n            lora_configs={\"attn\": lora.LoRAConfig(rank=32, alpha=32.0), \"ffn\": lora.LoRAConfig(rank=32, alpha=32.0)},\n        )\n    raise ValueError(f\"Unknown variant: {variant}\")\n\n\n@at.typecheck\nclass RMSNorm(nn.Module):\n    @nn.compact\n    def __call__(self, x, cond):\n        dtype = x.dtype  # original dtype, could be half-precision\n        var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)  # compute variance in float32\n        normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))  # compute normalization in float32\n        if cond is None:\n            # regular RMSNorm\n            scale = self.param(\"scale\", nn.initializers.zeros_init(), (x.shape[-1]))\n            normed_inputs = normed_inputs * (\n                1 + scale\n            )  # scale by learned parameter in float32 (matches Flax implementation)\n            return normed_inputs.astype(dtype), None  # return in original dtype\n\n        # adaptive RMSNorm\n        
modulation = nn.Dense(x.shape[-1] * 3, kernel_init=nn.initializers.zeros, dtype=dtype)(cond)\n        scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1)\n        normed_inputs = normed_inputs * (1 + scale) + shift  # scale and shift in float32\n        return normed_inputs.astype(dtype), gate\n\n\n@at.typecheck\nclass Embedder(nn.Module):\n    \"\"\"Embedder module.\"\"\"\n\n    vocab_size: int\n    embed_dim: int\n\n    def setup(self):\n        self.input_embedding_table = self.param(\n            \"input_embedding\",\n            nn.initializers.normal(),\n            (self.vocab_size, self.embed_dim),\n        )\n\n    def encode(self, x):\n        x = self.input_embedding_table[(x,)]\n        x *= jnp.sqrt(self.embed_dim).astype(x.dtype)\n        return x\n\n    def decode(self, x):\n        return jnp.dot(x, self.input_embedding_table.T)\n\n\n@at.typecheck\nclass Attention(nn.Module):\n    \"\"\"Attention module.\"\"\"\n\n    configs: Sequence[Config]\n\n    @nn.compact\n    def __call__(self, xs, positions, attn_mask, kv_cache):\n        # all experts must share the same head dim, num heads, and num kv heads for self-attention to work\n        assert all(config.head_dim == self.configs[0].head_dim for config in self.configs)\n        assert all(config.num_heads == self.configs[0].num_heads for config in self.configs)\n        assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs)\n\n        dtype = next(x.dtype for x in xs if x is not None)  # original dtype, could be half-precision\n\n        qkvs = []\n        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):\n            if x is None:\n                continue\n            if config.num_kv_heads == config.num_heads:\n                qkv_einsum = lora.Einsum(\n                    shape=(3, config.num_heads, config.width, config.head_dim),\n                    name=_name(\"qkv_einsum\", i),\n                    
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),\n                    lora_config=config.lora_configs.get(\"attn\"),\n                )\n                qkvs.append(qkv_einsum(\"BSD,3KDH->3BSKH\", x))\n            else:\n                q_einsum = lora.Einsum(\n                    shape=(config.num_heads, config.width, config.head_dim),\n                    name=_name(\"q_einsum\", i),\n                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),\n                    lora_config=config.lora_configs.get(\"attn\"),\n                )\n                q = q_einsum(\"BTD,NDH->BTNH\", x)\n                kv_einsum = lora.Einsum(\n                    shape=(2, config.num_kv_heads, config.width, config.head_dim),\n                    name=_name(\"kv_einsum\", i),\n                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),\n                    lora_config=config.lora_configs.get(\"attn\"),\n                )\n                k, v = kv_einsum(\"BSD,2KDH->2BSKH\", x)\n                qkvs.append((q, k, v))\n\n        q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))\n\n        q = _apply_rope(q, positions=positions)\n        q *= self.configs[0].head_dim ** -0.5\n\n        k = _apply_rope(k, positions=positions)\n\n        # should still be half-precision here (if input was half-precision)\n        assert q.dtype == k.dtype == v.dtype == dtype\n\n        if kv_cache is not None:\n            cache_k, cache_v = kv_cache\n            k = jnp.concatenate([cache_k, k], axis=1)\n            v = jnp.concatenate([cache_v, v], axis=1)\n\n        q = einops.rearrange(q, \"B T (K G) H -> B T K G H\", K=self.configs[0].num_kv_heads)\n        logits = jnp.einsum(\"BTKGH,BSKH->BKGTS\", q, k, preferred_element_type=jnp.float32)\n\n        if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):\n            raise ValueError(\n                
f\"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}\"\n            )\n\n        # big_neg = jnp.finfo(logits.dtype).min\n        big_neg = -2.3819763e38  # See gemma/modules.py\n        masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)\n\n        probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)\n\n        encoded = jnp.einsum(\"BKGTS,BSKH->BTKGH\", probs, v)\n        encoded = einops.rearrange(encoded, \"B T K G H -> B T (K G) H\")\n\n        out = []\n        start = 0\n        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):\n            if x is not None:\n                end = start + x.shape[1]\n                out_einsum = lora.Einsum(\n                    shape=(config.num_heads, config.head_dim, config.width),\n                    name=_name(\"attn_vec_einsum\", i),\n                    init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),\n                    lora_config=config.lora_configs.get(\"attn\"),\n                )\n                out.append(out_einsum(\"BTNH,NHD->BTD\", encoded[:, start:end]))\n                start = end\n            else:\n                out.append(None)\n\n        return out, (k, v)\n\n\n@at.typecheck\nclass FeedForward(nn.Module):\n    \"\"\"Feed forward module.\"\"\"\n\n    features: int\n    hidden_dim: int\n\n    @nn.compact\n    def __call__(self, x):\n        dtype = x.dtype  # original dtype, could be half-precision\n        w_gating = self.param(\n            \"gating_einsum\",\n            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),\n            (2, self.features, self.hidden_dim),\n        ).astype(dtype)\n        ff_gate = jnp.dot(x, w_gating[0])\n        gate_value = nn.gelu(ff_gate)\n\n        ff1 = jnp.dot(x, w_gating[1])\n        activations = gate_value * ff1\n\n        w_linear = self.param(\n            \"linear\",\n            
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),\n            (self.hidden_dim, self.features),\n        ).astype(dtype)\n        outputs = jnp.dot(activations, w_linear)\n        assert outputs.dtype == dtype\n        return outputs\n\n\n@at.typecheck\nclass Block(nn.Module):\n    \"\"\"Transformer block.\"\"\"\n\n    configs: tuple[Config, ...]\n\n    dropout: float = 0.0\n    dropout_bdims: tuple[int, ...] = ()\n\n    @nn.compact\n    def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, deterministic=True):  # noqa: FBT002\n        xs = sharding.activation_sharding_constraint(xs)\n        drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x\n\n        attn = Attention(configs=self.configs, name=\"attn\")\n\n        pre_attn = []\n        gates = []\n        for i, x in enumerate(xs):\n            if x is not None:\n                x, gate = RMSNorm(name=_name(\"pre_attention_norm\", i))(x, adarms_cond[i])  # noqa: PLW2901\n            pre_attn.append(x)\n            gates.append(gate if x is not None else None)\n\n        pre_attn = sharding.activation_sharding_constraint(pre_attn)\n        post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache)\n        post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)\n        post_attn = sharding.activation_sharding_constraint(post_attn)\n        xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, post_attn, gates, strict=True)]\n        xs = sharding.activation_sharding_constraint(xs)\n\n        out = []\n        gates = []\n        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):\n            if x is not None:\n                x, gate = RMSNorm(name=_name(\"pre_ffw_norm\", i))(x, adarms_cond[i])  # noqa: PLW2901\n                x = lora.FeedForward(  # noqa: PLW2901\n                    features=config.width,\n                    hidden_dim=config.mlp_dim,\n                    name=_name(\"mlp\", i),\n       
             lora_config=config.lora_configs.get(\"ffn\"),\n                )(x)\n            out.append(x)\n            gates.append(gate if x is not None else None)\n\n        out = sharding.activation_sharding_constraint(out)\n        out = jax.tree.map(lambda x: drop(x, deterministic), out)\n        xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, out, gates, strict=True)]\n        xs = sharding.activation_sharding_constraint(xs)\n\n        return xs, kv_cache\n\n\nKVCache: TypeAlias = tuple[at.Float[at.Array, \"l b _t _k _h\"], at.Float[at.Array, \"l b _t _v _h\"]]\n\n\n@at.typecheck\nclass Module(nn.Module):\n    \"\"\"Transformer model, supporting a mixture of different weights for different tokens.\"\"\"\n\n    configs: Sequence[Config]  # list of configs, one for each expert\n    embed_dtype: str\n\n    dropout: float = 0.0\n    dropout_bdims: tuple[int, ...] = ()  # Every float is dropped independently.\n    adarms: bool = False\n\n    def setup(self):\n        # all experts must have the same depth\n        assert all(config.depth == self.configs[0].depth for config in self.configs)\n\n        self.embedder = Embedder(\n            vocab_size=PALIGEMMA_VOCAB_SIZE,\n            embed_dim=self.configs[0].width,  # embedder for first expert only\n            name=\"embedder\",\n        )\n        block_cls = nn.remat(\n            Block,\n            prevent_cse=False,\n            static_argnums=(5,),  # 0=self, 6=deterministic\n            policy=jax.checkpoint_policies.nothing_saveable,\n        )\n        self.layers = nn.scan(\n            block_cls,\n            variable_axes={\"params\": 0},\n            split_rngs={\"params\": True, \"dropout\": True},\n            in_axes=(\n                0,\n                nn.broadcast,\n                nn.broadcast,\n                nn.broadcast,\n                nn.broadcast,\n            ),  # 0=kv_cache, 1=positions, 2=mask, 3=adarms_cond, 4=deterministic\n            
length=self.configs[0].depth,\n        )(\n            configs=self.configs,\n            dropout=self.dropout,\n            dropout_bdims=self.dropout_bdims,\n        )\n        self.final_norms = [RMSNorm(name=_name(\"final_norm\", i)) for i in range(len(self.configs))]\n\n    @at.typecheck\n    def embed(self, tokens: at.Int[at.Array, \"b t\"]) -> at.Float[at.Array, \"b t d\"]:\n        return self.embedder.encode(tokens).astype(self.embed_dtype)\n\n    @at.typecheck\n    def __call__(\n        self,\n        # list of token arrays, one for each expert, or None if that expert should not be run\n        embedded: Sequence[at.Float[at.Array, \"b _t _d\"] | None],\n        positions: at.Int[at.Array, \"b t\"],\n        mask: at.Bool[at.Array, \"b t s\"],\n        adarms_cond: Sequence[at.Float[at.Array, \"b _d\"] | None] | None = None,\n        *,\n        kv_cache: KVCache | None = None,\n        deterministic: bool = True,\n    ) -> tuple[Sequence[at.Float[at.Array, \"b _t _d\"] | None], KVCache]:\n        embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)\n        mask = jnp.asarray(mask)[:, None, :, :]\n        if adarms_cond is None:\n            adarms_cond = [None] * len(self.configs)\n\n        embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, adarms_cond, deterministic)\n\n        assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None)\n\n        return [\n            f(e, a)[0] if e is not None else e for f, e, a in zip(self.final_norms, embedded, adarms_cond, strict=True)\n        ], kv_cache\n\n    def init(self, use_adarms: Sequence[bool]):\n        \"\"\"Convenience method for initializing all parameters, necessary due to the quirks of linen.\"\"\"\n        self.embed(jnp.zeros((1, 1), dtype=jnp.int32))\n        self(\n            [jnp.zeros((1, 1, c.width)) for c in self.configs],\n            jnp.zeros((1, len(self.configs)), dtype=jnp.int32),\n            jnp.zeros((1, 
len(self.configs), len(self.configs)), dtype=bool),\n            adarms_cond=[jnp.zeros((1, c.width)) if u else None for u, c in zip(use_adarms, self.configs, strict=True)],\n        )\n\n\ndef _apply_rope(x, *, positions, max_wavelength=10_000):\n    \"\"\"Applies RoPE positions [B, L] to x [B, L, H, D].\"\"\"\n    freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)\n    timescale = max_wavelength**freq_exponents\n    radians = positions[..., None] / timescale[None, None, :]\n    radians = radians[..., None, :]\n    assert radians.dtype == jnp.float32\n    # radians.shape = [...,L,1,d=D/2]\n    sin, cos = jnp.sin(radians), jnp.cos(radians)\n    x1, x2 = jnp.split(x, 2, axis=-1)\n    res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)\n    assert res.dtype == jnp.float32\n    # The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache\n    # dtype when in inference mode (but not in training mode). I don't think any of this was intentional. Based on the\n    # original DeepMind impl, as well as the widely-used transformers impl, it is ok to always downcast back to bfloat16\n    # here.\n    return res.astype(x.dtype)\n\n\ndef _name(name, i):\n    # we name layers like this because we want the first expert's weights to have no suffix (e.g., \"attn\"), so that they\n    # can be loaded seamlessly from the existing PaliGemma checkpoint. subsequent experts will have a suffix (e.g.,\n    # \"attn_1\") and their weights will be initialized from scratch. in practice, we only use two experts -- PaliGemma,\n    # and the action expert.\n    if i == 0:\n        return name\n    return f\"{name}_{i}\"\n\n\ndef _gated_residual(x, y, gate):\n    assert (x is None) == (y is None)\n    if x is None:\n        return None\n    if gate is None:\n        return x + y\n    return x + y * gate\n"
  },
  {
    "path": "src/openpi/models/gemma_fast.py",
    "content": "# Copyright 2024 Big Vision Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nGemma model implementation from big_vision/models/ppp/gemma.py (with small modifications for NNX compatibility)\nUsed for FAST autoregressive policies.\n\"\"\"\n\nimport dataclasses\nfrom typing import Literal, TypeAlias\n\nimport einops\nimport flax.linen as nn\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\n\nimport openpi.models.lora as lora\nimport openpi.shared.array_typing as at\n\nVariant = Literal[\"gemma_2b\", \"gemma_2b_lora\"]\n\n\ndef get_config(variant):\n    \"\"\"Returns config for specified gemma variant.\"\"\"\n    if variant == \"gemma_2b\":\n        return ml_collections.ConfigDict(\n            {\n                \"variant\": variant,\n                \"width\": 2048,\n                \"depth\": 18,\n                \"mlp_dim\": 16_384,\n                \"num_heads\": 8,\n                \"num_kv_heads\": 1,\n                \"head_dim\": 256,\n                \"norm_eps\": 1e-6,\n                \"vocab_size\": 257_152,\n                \"scan\": True,\n                \"remat_policy\": \"nothing_saveable\",\n            }\n        )\n    if variant == \"gemma_2b_lora\":\n        return ml_collections.ConfigDict(\n            {\n                \"variant\": variant,\n                \"width\": 2048,\n                \"depth\": 18,\n                \"mlp_dim\": 16_384,\n                \"num_heads\": 8,\n               
 \"num_kv_heads\": 1,\n                \"head_dim\": 256,\n                \"norm_eps\": 1e-6,\n                \"vocab_size\": 257_152,\n                \"scan\": True,\n                \"remat_policy\": \"nothing_saveable\",\n                \"lora_configs\": {\n                    \"attn\": lora.LoRAConfig(rank=16, alpha=16.0),\n                    \"ffn\": lora.LoRAConfig(rank=16, alpha=16.0),\n                },\n            }\n        )\n    raise ValueError(f\"Unknown variant: {variant}\")\n\n\n@at.typecheck\nclass Einsum(nn.Module):\n    shape: tuple[int, ...]\n\n    @nn.compact\n    def __call__(self, eqn, x):\n        dtype = x.dtype  # original dtype, could be half-precision\n        w = self.param(\"w\", nn.initializers.zeros_init(), self.shape).astype(dtype)\n        return jnp.einsum(eqn, x, w)\n\n\n@at.typecheck\nclass RMSNorm(nn.Module):\n    @nn.compact\n    def __call__(self, x):\n        dtype = x.dtype  # original dtype, could be half-precision\n        scale = self.param(\"scale\", nn.initializers.zeros_init(), (x.shape[-1]))\n        var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)  # compute variance in float32\n        normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))  # compute normalization in float32\n        normed_inputs = normed_inputs * (\n            1 + scale\n        )  # scale by learned parameter in float32 (matches Flax implementation)\n        return normed_inputs.astype(dtype)  # return in original dtype\n\n\n@at.typecheck\nclass Embedder(nn.Module):\n    \"\"\"Embedder module.\"\"\"\n\n    vocab_size: int\n    embed_dim: int\n\n    def setup(self):\n        self.input_embedding_table = self.param(\n            \"input_embedding\",\n            nn.initializers.zeros_init(),\n            (self.vocab_size, self.embed_dim),\n        )\n\n    def encode(self, x):\n        x = self.input_embedding_table[(x,)]\n        x *= jnp.sqrt(self.embed_dim).astype(x.dtype)\n        return x\n\n   
 def decode(self, x):\n        return jnp.dot(x, self.input_embedding_table.T)\n\n\n@at.typecheck\nclass Attention(nn.Module):\n    \"\"\"Attention module.\"\"\"\n\n    num_heads: int\n    num_kv_heads: int\n    features: int\n    head_dim: int\n\n    cache_dtype: str | None = None\n\n    lora_config: lora.LoRAConfig | None = None\n\n    def setup(self):\n        if self.num_kv_heads == self.num_heads:\n            self.qkv_einsum = lora.Einsum(\n                shape=(3, self.num_heads, self.features, self.head_dim),\n                name=\"qkv_einsum\",\n                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),\n                lora_config=self.lora_config,\n            )\n        else:\n            self.q_einsum = lora.Einsum(\n                shape=(self.num_heads, self.features, self.head_dim),\n                name=\"q_einsum\",\n                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),\n                lora_config=self.lora_config,\n            )\n            self.kv_einsum = lora.Einsum(\n                shape=(2, self.num_kv_heads, self.features, self.head_dim),\n                name=\"kv_einsum\",\n                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),\n                lora_config=self.lora_config,\n            )\n        self.attn_vec_einsum = lora.Einsum(\n            shape=(self.num_heads, self.head_dim, self.features),\n            name=\"attn_vec_einsum\",\n            init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),\n            lora_config=self.lora_config,\n        )\n\n    def _init_cache(self, k, v, cache_size):\n        \"\"\"Initialize KV cache\"\"\"\n        prefill_len = k.shape[1]\n        pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0))\n        cache_dtype = self.cache_dtype or k.dtype\n        k_cache = jnp.pad(k.astype(cache_dtype), pad_width)\n        v_cache = 
jnp.pad(v.astype(cache_dtype), pad_width)\n        idx = jnp.zeros((k.shape[0],), dtype=jnp.int32) + prefill_len\n        return idx, k_cache, v_cache\n\n    def _update_cache(self, k, v, idx, k_cache, v_cache):\n        \"\"\"Update KV cache with new values\"\"\"\n        assert k.shape[1] == 1, \"Only support kv-cache updates of length 1\"\n        indices = (0, idx[0], 0, 0)\n        cache_dtype = self.cache_dtype or k.dtype\n        k_new = jax.lax.dynamic_update_slice(k_cache, k.astype(cache_dtype), indices)\n        v_new = jax.lax.dynamic_update_slice(v_cache, v.astype(cache_dtype), indices)\n        idx_new = idx + 1\n        return idx_new, k_new, v_new\n\n    @nn.compact\n    def __call__(self, x, positions, attn_mask, kv_cache, decode, deterministic=True):  # noqa: FBT002\n        dtype = x.dtype  # original dtype, could be half-precision\n        if self.num_kv_heads == self.num_heads:\n            q, k, v = self.qkv_einsum(\"BSD,3KDH->3BSKH\", x)\n        else:\n            q = self.q_einsum(\"BTD,NDH->BTNH\", x)\n            k, v = self.kv_einsum(\"BSD,2KDH->2BSKH\", x)\n\n        q = _apply_rope(q, positions=positions)  # promotes to float32\n        q *= self.head_dim**-0.5\n\n        k = _apply_rope(k, positions=positions)  # promotes to float32\n\n        if kv_cache is None:\n            idx, k_cache, v_cache = self._init_cache(k, v, attn_mask.shape[-1])\n        else:\n            idx, k_cache, v_cache = kv_cache\n            idx, k_cache, v_cache = self._update_cache(k, v, idx, k_cache, v_cache)\n\n        k, v = k_cache, v_cache\n        kv_cache = (idx, k_cache, v_cache)\n\n        q = einops.rearrange(q, \"B T (K G) H -> B T K G H\", K=self.num_kv_heads)\n        logits = jnp.einsum(\"BTKGH,BSKH->BKGTS\", q, k, preferred_element_type=jnp.float32)\n\n        if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):\n            raise ValueError(\n                f\"Attention mask with shape {attn_mask.shape} but shapes for q and k are: 
{q.shape} and {k.shape}\"\n            )\n\n        # big_neg = jnp.finfo(logits.dtype).min\n        big_neg = -2.3819763e38  # See gemma/modules.py\n        masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)\n\n        probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)\n\n        encoded = jnp.einsum(\"BKGTS,BSKH->BTKGH\", probs, v)\n        encoded = einops.rearrange(encoded, \"B T K G H -> B T (K G) H\")\n        return self.attn_vec_einsum(\"BTNH,NHD->BTD\", encoded), kv_cache\n\n\n@at.typecheck\nclass Block(nn.Module):\n    \"\"\"Transformer block.\"\"\"\n\n    num_heads: int\n    num_kv_heads: int\n    embed_dim: int\n    head_dim: int\n    hidden_dim: int\n\n    dropout: float = 0.0\n    dropout_bdims: tuple[int, ...] = ()\n    cache_dtype: str | None = None\n    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)\n\n    def setup(self):\n        self.pre_attention_norm = RMSNorm()\n        self.attn = Attention(\n            num_heads=self.num_heads,\n            num_kv_heads=self.num_kv_heads,\n            features=self.embed_dim,\n            head_dim=self.head_dim,\n            cache_dtype=self.cache_dtype,\n            lora_config=self.lora_configs.get(\"attn\"),\n        )\n        self.pre_ffw_norm = RMSNorm()\n        self.mlp = lora.FeedForward(\n            features=self.embed_dim, hidden_dim=self.hidden_dim, name=\"mlp\", lora_config=self.lora_configs.get(\"ffn\")\n        )\n        if self.dropout:\n            self.drop = nn.Dropout(self.dropout, self.dropout_bdims)\n        else:\n            self.drop = lambda x, _: x\n\n    def __call__(self, x, kv_cache, positions, attn_mask, decode, deterministic=True):  # noqa: FBT002\n        x = nn.with_logical_constraint(x, (\"act_batch\", \"act_len\", \"act_emb\"))\n        inputs_normalized = self.pre_attention_norm(x)\n        attn_output, kv_cache = self.attn(inputs_normalized, positions, attn_mask, kv_cache, decode, 
deterministic)\n        attn_output = self.drop(attn_output, deterministic)\n        attn_output += x\n        residual = attn_output\n        attn_output = self.pre_ffw_norm(attn_output)\n        outputs = self.mlp(attn_output)\n        outputs = self.drop(outputs, deterministic)\n        outputs = residual + outputs\n        return outputs, kv_cache\n\n\nKVCache: TypeAlias = tuple[at.Int[at.Array, \" b\"], at.Float[at.Array, \"b _t _k _h\"], at.Float[at.Array, \"b _t _v _h\"]]\n\n\n@at.typecheck\nclass Module(nn.Module):\n    \"\"\"gemma model.\"\"\"\n\n    variant: str\n\n    width: int\n    depth: int\n    mlp_dim: int\n    num_heads: int\n    num_kv_heads: int\n    head_dim: int\n    norm_eps: float\n    vocab_size: int\n    embed_dtype: str\n\n    dropout: float = 0.0\n    dropout_bdims: tuple[int, ...] = ()  # Every float is dropped independently.\n    cache_dtype: str | None = None\n\n    scan: bool = False\n    remat_policy: str = \"none\"\n    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)\n\n    @nn.compact\n    def __call__(\n        self,\n        tokens=None,\n        embedded_prefix=None,\n        embed_only=False,  # noqa: FBT002\n        pre_logits=None,\n        positions=None,\n        mask=None,\n        decode=False,  # noqa: FBT002\n        kv_cache=None,\n        deterministic=True,  # noqa: FBT002\n        return_prelogits=False,  # noqa: FBT002\n    ):\n        \"\"\"Embed only, or complete forward pass.\n\n        Args:\n          tokens: Embedded, then and appended to `embedded_prefix`. 
Can be None.\n          embedded_prefix: Optional prefix that is already embedded.\n          embed_only: Whether to compute embeddings only.\n          pre_logits: If present computes logits from pre_logits and returns.\n          positions: Optional `[B, T]` allows to specify the absolute position of\n            the tokens.\n          mask: Optional attention mask `[B, T, S]`.\n          decode: Whether to use kv-cache. Caller must pass masks and positions.\n          deterministic: Forwarded to all dropout layers.\n          return_prelogits: Whether to return the pre-logits.\n\n        Returns:\n          If `embed_only=False`, then `(logits, out)` will be returned.\n          If `embed_only=True`, then the embeddings will be returned.\n          If `return_prelogits=True`, then the pre-logits will be returned.\n        \"\"\"\n        out = {}\n\n        embedder = Embedder(vocab_size=self.vocab_size, embed_dim=self.width, name=\"embedder\")\n\n        if pre_logits is not None:\n            x = out[\"pre_logits\"] = pre_logits\n            logits = out[\"logits\"] = embedder.decode(x)\n            return logits, out\n\n        x = []\n        if embedded_prefix is not None:\n            x.append(embedded_prefix)\n        if tokens is not None:\n            x.append(embedder.encode(tokens))\n\n        x = jnp.concatenate(x, axis=-2)\n        x = x.astype(self.embed_dtype)\n        batch_size, seq_len, width = x.shape\n\n        if embed_only:\n            return x\n\n        if decode:\n            assert positions is not None and mask is not None, (  # noqa: PT018\n                \"Must explicitly pass positions and mask for decoding.\"\n            )\n\n        if positions is None:\n            positions = jnp.arange(seq_len).astype(jnp.int32)[None, :]\n        assert positions.shape[1] == x.shape[1], (positions.shape, x.shape)\n\n        if mask is None:\n            mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len]))\n        if 
mask.ndim == 3:\n            mask = mask[:, None, :, :]\n        cache_size = max(seq_len, mask.shape[-1])\n        assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape\n\n        if self.remat_policy == \"none\":\n            block_cls = Block\n        else:\n            block_cls = nn.remat(\n                Block,\n                prevent_cse=not self.scan,\n                static_argnums=(5, 6),  # 0=self, 5=decode, 6=deterministic\n                policy=getattr(jax.checkpoint_policies, self.remat_policy),\n            )\n\n        block_kw = {\n            \"num_heads\": self.num_heads,\n            \"head_dim\": self.head_dim,\n            \"num_kv_heads\": self.num_kv_heads,\n            \"embed_dim\": width,\n            \"hidden_dim\": self.mlp_dim,\n            \"dropout\": self.dropout,\n            \"dropout_bdims\": self.dropout_bdims,\n            \"cache_dtype\": self.cache_dtype,\n            \"lora_configs\": self.lora_configs,\n        }\n        layers = self.scope.push(\"layers\")\n        blocks = [\n            nn.scan(\n                block_cls,\n                variable_axes={\"params\": 0},\n                split_rngs={\"params\": True, \"dropout\": True},\n                in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),  # 0=kv_cache, 1=positions, 2=mask\n                length=self.depth,\n            )(parent=layers, **block_kw)\n        ]\n        for block in blocks:\n            x, kv_cache = block(x, kv_cache, positions, mask, decode, deterministic)\n\n        assert x.dtype == jnp.dtype(self.embed_dtype)  # Sanity check.\n        out[\"encoded\"] = x\n\n        x = RMSNorm(name=\"final_norm\")(x)\n        out[\"pre_logits\"] = x\n        if return_prelogits:\n            return x, kv_cache, out\n\n        x = embedder.decode(x)\n        out[\"logits\"] = x\n\n        return x, kv_cache, out\n\n    def init(self):\n        \"\"\"Convenience method for initializing all parameters, necessary due 
to the quirks of linen.\"\"\"\n        self(jnp.zeros((1, 1), dtype=jnp.int32))\n\n\ndef _apply_rope(x, *, positions, max_wavelength=10_000):\n    \"\"\"Applies RoPE positions [B, L] to x [B, L, H, D].\"\"\"\n    freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)\n    timescale = max_wavelength**freq_exponents\n    radians = positions[..., None] / timescale[None, None, :]\n    radians = radians[..., None, :]\n    assert radians.dtype == jnp.float32\n    # radians.shape = [...,L,1,d=D/2]\n    sin, cos = jnp.sin(radians), jnp.cos(radians)\n    x1, x2 = jnp.split(x, 2, axis=-1)\n    res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)\n    assert res.dtype == jnp.float32\n    return res\n"
  },
  {
    "path": "src/openpi/models/lora.py",
    "content": "import math\nimport re\n\nimport flax.linen as nn\nimport flax.struct as struct\nimport jax.numpy as jnp\n\nimport openpi.shared.array_typing as at\n\n\n@struct.dataclass\nclass LoRAConfig:\n    \"\"\"Configuration for LoRA.\"\"\"\n\n    # LoRA rank.\n    rank: int\n    # LoRA scaling factor.\n    alpha: float = 1.0\n    # Initialization function for LoRA parameters.\n    init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01)\n    # Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732\n    rslora: bool = False\n    # Axes in the weight to apply LoRA to. Should typically be the last two axes.\n    axes: tuple[int, int] = (-2, -1)\n    # Axis label which is used by LoRA in einsum equations. Must not be present in the original equation.\n    label: str = \"L\"\n\n    @property\n    def scaling_value(self) -> float:\n        return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank\n\n\nclass Einsum(nn.Module):\n    \"\"\"Einsum with LoRA support. 
Can be used as a drop-in replacement for the Gemma Einsum.\"\"\"\n\n    # Shape of the weight.\n    shape: tuple[int, ...]\n    # Initialization function for the weight.\n    init_fn: nn.initializers.Initializer = nn.initializers.zeros\n    # If not None, apply LoRA to the weight.\n    lora_config: LoRAConfig | None = None\n\n    def setup(self):\n        self.w = self.param(\"w\", self.init_fn, self.shape)\n\n        if config := self.lora_config:\n            # Setup LoRA parameters.\n            shape_a, shape_b = list(self.shape), list(self.shape)\n            shape_a[config.axes[1]] = config.rank\n            shape_b[config.axes[0]] = config.rank\n            self.w_a = self.param(\"lora_a\", config.init_fn, shape_a)\n            self.w_b = self.param(\"lora_b\", config.init_fn, shape_b)\n\n    @nn.compact\n    def __call__(self, eqn: str, x):\n        dtype = x.dtype  # original dtype, could be half-precision\n        result = jnp.einsum(eqn, x, self.w.astype(dtype))\n\n        if config := self.lora_config:\n            eqn_a, eqn_b = self._make_lora_eqns(eqn)\n            lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))\n            lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))\n            result = result + lora * config.scaling_value\n\n        return result\n\n    def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:\n        if \"L\" in eqn:\n            raise ValueError(f\"L already in eqn: {eqn}\")\n        if not (m := re.match(\"(.*),(.*)->(.*)\", eqn)):\n            raise ValueError(f\"Unsupported einsum eqn: {eqn}\")\n        lhs, rhs, out = m.groups()\n\n        assert self.lora_config is not None\n        a_label, b_label = (rhs[x] for x in self.lora_config.axes)\n        label = self.lora_config.label\n\n        a_rhs = rhs.replace(b_label, label)\n        a_out = out.replace(b_label, label)\n        eqn_a = f\"{lhs},{a_rhs}->{a_out}\"\n\n        b_rhs = rhs.replace(a_label, label)\n        eqn_b = f\"{a_out},{b_rhs}->{out}\"\n\n  
      return eqn_a, eqn_b\n\n\nclass FeedForward(nn.Module):\n    \"\"\"Feed forward module.\"\"\"\n\n    features: int\n    hidden_dim: int\n    # If not None, apply LoRA to the weight.\n    lora_config: LoRAConfig | None = None\n\n    def setup(self):\n        self.w_gating = self.param(\n            \"gating_einsum\",\n            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),\n            (2, self.features, self.hidden_dim),\n        )\n        self.w_linear = self.param(\n            \"linear\",\n            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),\n            (self.hidden_dim, self.features),\n        )\n        self.w_gating_lora = None\n        self.w_linear_lora = None\n        if self.lora_config:\n            # Setup LoRA parameters.\n            # TODO: follow up with a simplified init_fn api.\n            self.w_gating_lora = (\n                self.param(\"gating_einsum_lora_a\", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)),\n                self.param(\n                    \"gating_einsum_lora_b\", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim)\n                ),\n            )\n            self.w_linear_lora = (\n                self.param(\"linear_lora_a\", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)),\n                self.param(\"linear_lora_b\", self.lora_config.init_fn, (self.lora_config.rank, self.features)),\n            )\n\n    @nn.compact\n    def __call__(self, x):\n        dtype = x.dtype  # original dtype, could be half-precision\n        ff_gate = self._dot(\n            x,\n            self.w_gating[0],\n            None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]),\n        )\n        gate_value = nn.gelu(ff_gate)\n\n        ff1 = self._dot(\n            x,\n            self.w_gating[1],\n            None if self.w_gating_lora is None else (self.w_gating_lora[0][1], 
self.w_gating_lora[1][1]),\n        )\n        activations = gate_value * ff1\n\n        outputs = self._dot(activations, self.w_linear, self.w_linear_lora)\n        assert outputs.dtype == dtype\n        return outputs\n\n    def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array:\n        base = jnp.dot(x, w.astype(x.dtype))\n        if lora_weights is None:\n            return base\n        return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype))\n"
  },
  {
    "path": "src/openpi/models/lora_test.py",
    "content": "import flax.linen as nn\nimport jax\nimport jax.numpy as jnp\n\nimport openpi.models.lora as lora\n\n\ndef test_lora_einsum_params_shape():\n    shape = (3, 8, 32, 4)  # (3KDH)\n    einsum = lora.Einsum(shape)\n    lora0 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2))\n    lora1 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2)))\n\n    key = jax.random.key(0)\n    x = jax.random.normal(key, (8, 64, 32))  # (BSD)\n    eqn = \"BSD,3KDH->3BSKH\"\n\n    # Ensure that lora parameters are not initialized when LoRA is not used.\n    params = einsum.init(key, eqn, x)\n    assert \"lora_a\" not in params[\"params\"]\n    assert \"lora_b\" not in params[\"params\"]\n\n    # Check that default axes work.\n    params_lora0 = lora0.init(key, eqn, x)\n    assert params_lora0[\"params\"][\"lora_a\"].shape == (3, 8, 32, 2)\n    assert params_lora0[\"params\"][\"lora_b\"].shape == (3, 8, 2, 4)\n\n    # Check that user provided axes work.\n    params_lora1 = lora1.init(key, eqn, x)\n    assert params_lora1[\"params\"][\"lora_a\"].shape == (3, 8, 2, 4)\n    assert params_lora1[\"params\"][\"lora_b\"].shape == (3, 2, 32, 4)\n\n\ndef test_lora_einsum_same_output():\n    shape = (3, 8, 32, 4)  # (3KDH)\n    einsum = lora.Einsum(shape)\n    einsum_lora = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros))\n\n    key = jax.random.key(0)\n    x = jax.random.normal(key, (8, 64, 32))  # (BSD)\n    eqn = \"BSD,3KDH->3BSKH\"\n\n    params = einsum.init(key, eqn, x)\n    output = einsum.apply(params, eqn, x)\n\n    params_lora = einsum_lora.init(key, eqn, x)\n    output_lora = einsum_lora.apply(params_lora, eqn, x)\n\n    # Results are the same since the LoRA parameters are initialized to zeros.\n    assert jnp.allclose(output, output_lora)\n\n\ndef test_lora_ffn_params_shape():\n    ffn = lora.FeedForward(features=8, hidden_dim=32)\n    ffn_lora = lora.FeedForward(\n        features=8,\n        hidden_dim=32,\n    
    lora_config=lora.LoRAConfig(rank=2),\n    )\n\n    key = jax.random.key(0)\n    x = jax.random.normal(key, (2, 8))\n\n    params = ffn.init(key, x)\n    assert params[\"params\"][\"gating_einsum\"].shape == (2, 8, 32)\n    assert params[\"params\"][\"linear\"].shape == (32, 8)\n\n    params_lora = ffn_lora.init(key, x)\n    assert params_lora[\"params\"][\"gating_einsum\"].shape == (2, 8, 32)\n    assert params_lora[\"params\"][\"linear\"].shape == (32, 8)\n    assert params_lora[\"params\"][\"gating_einsum_lora_a\"].shape == (2, 8, 2)\n    assert params_lora[\"params\"][\"gating_einsum_lora_b\"].shape == (2, 2, 32)\n    assert params_lora[\"params\"][\"linear_lora_a\"].shape == (32, 2)\n    assert params_lora[\"params\"][\"linear_lora_b\"].shape == (2, 8)\n\n\ndef test_lora_ffn_same_output():\n    ffn = lora.FeedForward(features=8, hidden_dim=32)\n    ffn_lora = lora.FeedForward(\n        features=8,\n        hidden_dim=32,\n        lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros),\n    )\n\n    key = jax.random.key(0)\n    x = jax.random.normal(key, (2, 8))\n\n    params = ffn.init(key, x)\n    output = ffn.apply(params, x)\n\n    params_lora = ffn_lora.init(key, x)\n    output_lora = ffn_lora.apply(params_lora, x)\n\n    assert jnp.allclose(output, output_lora)\n"
  },
  {
    "path": "src/openpi/models/model.py",
    "content": "import abc\nfrom collections.abc import Sequence\nimport dataclasses\nimport enum\nimport logging\nimport pathlib\nfrom typing import Generic, TypeVar\n\nimport augmax\nfrom flax import nnx\nfrom flax import struct\nfrom flax import traverse_util\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport orbax.checkpoint as ocp\nimport safetensors\nimport torch\n\nfrom openpi.models_pytorch import pi0_pytorch\nfrom openpi.shared import image_tools\nimport openpi.shared.array_typing as at\n\nlogger = logging.getLogger(\"openpi\")\n\n# Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays)\nArrayT = TypeVar(\"ArrayT\", bound=jax.Array | torch.Tensor | np.ndarray)\n\n\nclass ModelType(enum.Enum):\n    \"\"\"Supported model types.\"\"\"\n\n    PI0 = \"pi0\"\n    PI0_FAST = \"pi0_fast\"\n    PI05 = \"pi05\"\n\n\n# The model always expects these images\nIMAGE_KEYS = (\n    \"base_0_rgb\",\n    \"left_wrist_0_rgb\",\n    \"right_wrist_0_rgb\",\n)\n\n\n# This may need change if we release a small model.\nIMAGE_RESOLUTION = (224, 224)\n\n\n# Data format\n#\n# Data transforms produce the model input as a nested dictionary which is later converted\n# into `Obesrvation` and `Actions` objects. See below.\n#\n# In the dictory form, this data should look like:\n# {\n#     # Observation data.\n#     \"image\": {\n#         \"base_0_rgb\": (float32|uint8)[*b, h, w, 3],  # RGB image in [-1, 1] or [0, 255]\n#         ...  # Additional camera views\n#     },\n#     \"image_mask\": {\n#         \"base_0_rgb\": bool[*b],  # True if image is valid\n#         ...  
# Masks for additional views\n#     },\n#     \"state\": float32[*b, s],  # Low-dimensional robot state\n#     \"tokenized_prompt\": int32[*b, l],  # Optional, tokenized language prompt\n#     \"tokenized_prompt_mask\": bool[*b, l],  # Optional, mask for tokenized prompt\n#     \"token_ar_mask\": int32[*b, l],  # Optional, autoregressive mask for FAST model\n#     \"token_loss_mask\": bool[*b, l],  # Optional, loss mask for FAST model\n#\n#      # Actions data.\n#      \"actions\": float32[*b ah ad]\n# }\n# where:\n#   *b = batch dimensions\n#   h,w = image height/width\n#   s = state dimension\n#   l = sequence length\n#\n@at.typecheck\n@struct.dataclass\nclass Observation(Generic[ArrayT]):\n    \"\"\"Holds observations, i.e., inputs to the model.\n\n    See `Observation.from_dict` to see the expected dictionary form. This is the format\n    that should be produced by the data transforms.\n    \"\"\"\n\n    # Images, in [-1, 1] float32.\n    images: dict[str, at.Float[ArrayT, \"*b h w c\"]]\n    # Image masks, with same keys as images.\n    image_masks: dict[str, at.Bool[ArrayT, \"*b\"]]\n    # Low-dimensional robot state.\n    state: at.Float[ArrayT, \"*b s\"]\n\n    # Tokenized prompt.\n    tokenized_prompt: at.Int[ArrayT, \"*b l\"] | None = None\n    # Tokenized prompt mask.\n    tokenized_prompt_mask: at.Bool[ArrayT, \"*b l\"] | None = None\n\n    # pi0-fast model specific fields.\n\n    # Token auto-regressive mask (for FAST autoregressive model).\n    token_ar_mask: at.Int[ArrayT, \"*b l\"] | None = None\n    # Token loss mask (for FAST autoregressive model).\n    token_loss_mask: at.Bool[ArrayT, \"*b l\"] | None = None\n\n    @classmethod\n    def from_dict(cls, data: at.PyTree[ArrayT]) -> \"Observation[ArrayT]\":\n        \"\"\"This method defines the mapping between unstructured data (i.e., nested dict) to the structured Observation format.\"\"\"\n        # Ensure that tokenized_prompt and tokenized_prompt_mask are provided together.\n        if 
(\"tokenized_prompt\" in data) != (\"tokenized_prompt_mask\" in data):\n            raise ValueError(\"tokenized_prompt and tokenized_prompt_mask must be provided together.\")\n        # If images are uint8, convert them to [-1, 1] float32.\n        for key in data[\"image\"]:\n            if data[\"image\"][key].dtype == np.uint8:\n                data[\"image\"][key] = data[\"image\"][key].astype(np.float32) / 255.0 * 2.0 - 1.0\n            elif hasattr(data[\"image\"][key], \"dtype\") and data[\"image\"][key].dtype == torch.uint8:\n                data[\"image\"][key] = data[\"image\"][key].to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0\n        return cls(\n            images=data[\"image\"],\n            image_masks=data[\"image_mask\"],\n            state=data[\"state\"],\n            tokenized_prompt=data.get(\"tokenized_prompt\"),\n            tokenized_prompt_mask=data.get(\"tokenized_prompt_mask\"),\n            token_ar_mask=data.get(\"token_ar_mask\"),\n            token_loss_mask=data.get(\"token_loss_mask\"),\n        )\n\n    def to_dict(self) -> at.PyTree[ArrayT]:\n        \"\"\"Convert the Observation to a nested dict.\"\"\"\n        result = dataclasses.asdict(self)\n        result[\"image\"] = result.pop(\"images\")\n        result[\"image_mask\"] = result.pop(\"image_masks\")\n        return result\n\n\n# Defines the format of the actions. 
This field is included as \"actions\" inside the dictionary\n# produced by the data transforms.\nActions = at.Float[ArrayT, \"*b ah ad\"]\n\n\ndef preprocess_observation(\n    rng: at.KeyArrayLike | None,\n    observation: Observation,\n    *,\n    train: bool = False,\n    image_keys: Sequence[str] = IMAGE_KEYS,\n    image_resolution: tuple[int, int] = IMAGE_RESOLUTION,\n) -> Observation:\n    \"\"\"Preprocess the observations by performing image augmentations (if train=True), resizing (if necessary), and\n    filling in a default image mask (if necessary).\n    \"\"\"\n\n    if not set(image_keys).issubset(observation.images):\n        raise ValueError(f\"images dict missing keys: expected {image_keys}, got {list(observation.images)}\")\n\n    batch_shape = observation.state.shape[:-1]\n\n    out_images = {}\n    for key in image_keys:\n        image = observation.images[key]\n        if image.shape[1:3] != image_resolution:\n            logger.info(f\"Resizing image {key} from {image.shape[1:3]} to {image_resolution}\")\n            image = image_tools.resize_with_pad(image, *image_resolution)\n\n        if train:\n            # Convert from [-1, 1] to [0, 1] for augmax.\n            image = image / 2.0 + 0.5\n\n            transforms = []\n            if \"wrist\" not in key:\n                height, width = image.shape[1:3]\n                transforms += [\n                    augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),\n                    augmax.Resize(width, height),\n                    augmax.Rotate((-5, 5)),\n                ]\n            transforms += [\n                augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5),\n            ]\n            sub_rngs = jax.random.split(rng, image.shape[0])\n            image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image)\n\n            # Back to [-1, 1].\n            image = image * 2.0 - 1.0\n\n        out_images[key] = image\n\n    # obtain mask\n    out_masks = {}\n    for 
key in out_images:\n        if key not in observation.image_masks:\n            # do not mask by default\n            out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool)\n        else:\n            out_masks[key] = jnp.asarray(observation.image_masks[key])\n\n    return Observation(\n        images=out_images,\n        image_masks=out_masks,\n        state=observation.state,\n        tokenized_prompt=observation.tokenized_prompt,\n        tokenized_prompt_mask=observation.tokenized_prompt_mask,\n        token_ar_mask=observation.token_ar_mask,\n        token_loss_mask=observation.token_loss_mask,\n    )\n\n\n@dataclasses.dataclass(frozen=True)\nclass BaseModelConfig(abc.ABC):\n    \"\"\"Configuration shared by all models. Specific models should inherit from this class, and implement the `create`\n    method to create the corresponding model.\n    \"\"\"\n\n    # Action space dimension.\n    action_dim: int\n    # Action sequence length.\n    action_horizon: int\n    # Tokenized prompt maximum length.\n    max_token_len: int\n\n    @property\n    @abc.abstractmethod\n    def model_type(self) -> ModelType:\n        \"\"\"The model type.\"\"\"\n\n    @abc.abstractmethod\n    def create(self, rng: at.KeyArrayLike) -> \"BaseModel\":\n        \"\"\"Create a new model, initializing parameters.\"\"\"\n\n    def load(self, params: at.Params, *, remove_extra_params: bool = True) -> \"BaseModel\":\n        \"\"\"Create a model with the given parameters.\"\"\"\n        model = nnx.eval_shape(self.create, jax.random.key(0))\n        graphdef, state = nnx.split(model)\n        if remove_extra_params:\n            params = ocp.transform_utils.intersect_trees(state.to_pure_dict(), params)\n        at.check_pytree_equality(expected=state.to_pure_dict(), got=params, check_shapes=True, check_dtypes=False)\n        state.replace_by_pure_dict(params)\n        return nnx.merge(graphdef, state)\n\n    def load_pytorch(self, train_config, weight_path: str):\n        
logger.info(f\"train_config: {train_config}\")\n        model = pi0_pytorch.PI0Pytorch(config=train_config.model)\n        safetensors.torch.load_model(model, weight_path)\n        return model\n\n    @abc.abstractmethod\n    def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Actions]:\n        \"\"\"Returns the input specification for the model. Values are jax.ShapeDtypeStruct.\"\"\"\n\n    def fake_obs(self, batch_size: int = 1) -> Observation:\n        observation_spec, _ = self.inputs_spec(batch_size=batch_size)\n        return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), observation_spec)\n\n    def fake_act(self, batch_size: int = 1) -> Actions:\n        _, action_spec = self.inputs_spec(batch_size=batch_size)\n        return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), action_spec)\n\n\n@dataclasses.dataclass\nclass BaseModel(nnx.Module, abc.ABC):\n    \"\"\"Base class for all model implementations. Specific models should inherit from this class. They should call\n    super().__init__() to initialize the shared attributes (action_dim, action_horizon, and max_token_len).\n    \"\"\"\n\n    action_dim: int\n    action_horizon: int\n    max_token_len: int\n\n    @abc.abstractmethod\n    def compute_loss(\n        self,\n        rng: at.KeyArrayLike,\n        observation: Observation,\n        actions: Actions,\n        *,\n        train: bool = False,\n    ) -> at.Float[at.Array, \"*b ah\"]: ...\n\n    @abc.abstractmethod\n    def sample_actions(self, rng: at.KeyArrayLike, observation: Observation, **kwargs) -> Actions: ...\n\n\ndef restore_params(\n    params_path: pathlib.Path | str,\n    *,\n    restore_type: type[np.ndarray] | type[jax.Array] = jax.Array,\n    dtype: jnp.dtype | None = None,\n    sharding: jax.sharding.Sharding | None = None,\n) -> at.Params:\n    \"\"\"Restores unstructured params PyTree from a checkpoint.\n\n    This works with checkpoints saved with `save_state` during openpi training (see 
`training/checkpoints.py`) as\n    well as pre-trained checkpoints released for openpi.\n\n    Args:\n        params_path: The local path to the checkpoint directory.\n        restore_type: The type to restore the params as. Can be set to `np.ndarray` to load the params as a numpy array.\n        dtype: The dtype to restore all params as. If not provided, will use the original dtype from the checkpoint.\n        sharding: The sharding to use for the params. If not provided, the params will be replicated across all devices.\n\n    Returns:\n        The restored params.\n    \"\"\"\n    params_path = pathlib.Path(params_path).resolve() if not str(params_path).startswith(\"gs://\") else params_path\n\n    if restore_type is jax.Array and sharding is None:\n        mesh = jax.sharding.Mesh(jax.devices(), (\"x\",))\n        sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())\n\n    with ocp.PyTreeCheckpointer() as ckptr:\n        metadata = ckptr.metadata(params_path)\n        item = {\"params\": metadata[\"params\"]}\n\n        params = ckptr.restore(\n            params_path,\n            ocp.args.PyTreeRestore(\n                item=item,\n                restore_args=jax.tree.map(\n                    lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype), item\n                ),\n            ),\n        )[\"params\"]\n\n    # If the params were saved with `save_state` during openpi training, every key path will end with \"value\", which is\n    # added by `nnx.State`. We remove the \"value\" suffix here and always return what NNX calls a \"pure dict\".\n    flat_params = traverse_util.flatten_dict(params)\n    if all(kp[-1] == \"value\" for kp in flat_params):\n        flat_params = {kp[:-1]: v for kp, v in flat_params.items()}\n    return traverse_util.unflatten_dict(flat_params)\n"
  },
  {
    "path": "src/openpi/models/model_test.py",
    "content": "from flax import nnx\nimport jax\nimport pytest\n\nfrom openpi.models import model as _model\nfrom openpi.models import pi0_config\nfrom openpi.models import pi0_fast\nfrom openpi.shared import download\nfrom openpi.shared import nnx_utils\n\n\ndef test_pi0_model():\n    key = jax.random.key(0)\n    config = pi0_config.Pi0Config()\n    model = config.create(key)\n\n    batch_size = 2\n    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)\n\n    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)\n    assert loss.shape == (batch_size, config.action_horizon)\n\n    actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)\n    assert actions.shape == (batch_size, model.action_horizon, model.action_dim)\n\n\ndef test_pi0_lora_model():\n    key = jax.random.key(0)\n    config = pi0_config.Pi0Config(paligemma_variant=\"gemma_2b_lora\")\n    model = config.create(key)\n\n    batch_size = 2\n    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)\n\n    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)\n    assert loss.shape == (batch_size, config.action_horizon)\n\n    actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)\n    assert actions.shape == (batch_size, model.action_horizon, model.action_dim)\n\n\ndef test_pi0_fast_model():\n    key = jax.random.key(0)\n    config = pi0_fast.Pi0FASTConfig()\n    model = config.create(key)\n\n    batch_size = 2\n    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)\n\n    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)\n    assert loss.shape == (batch_size,)\n\n    actions = nnx_utils.module_jit(model.sample_actions)(key, obs)\n    assert actions.shape == (batch_size, 256)\n\n\ndef test_pi0_fast_lora_model():\n    key = jax.random.key(0)\n    config = pi0_fast.Pi0FASTConfig(paligemma_variant=\"gemma_2b_lora\")\n    model = config.create(key)\n\n    batch_size = 2\n    obs, act = 
config.fake_obs(batch_size), config.fake_act(batch_size)\n\n    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)\n    assert loss.shape == (batch_size,)\n\n    actions = nnx_utils.module_jit(model.sample_actions)(key, obs)\n    assert actions.shape == (batch_size, 256)\n\n    lora_filter = nnx_utils.PathRegex(\".*lora.*\")\n    model_state = nnx.state(model)\n\n    lora_state_elems = list(model_state.filter(lora_filter))\n    assert len(lora_state_elems) > 0\n\n\n@pytest.mark.manual\ndef test_model_restore():\n    key = jax.random.key(0)\n    config = pi0_config.Pi0Config()\n\n    batch_size = 2\n    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)\n\n    model = config.load(\n        _model.restore_params(download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_base/params\"))\n    )\n\n    loss = model.compute_loss(key, obs, act)\n    assert loss.shape == (batch_size, config.action_horizon)\n\n    actions = model.sample_actions(key, obs, num_steps=10)\n    assert actions.shape == (batch_size, model.action_horizon, model.action_dim)\n"
  },
  {
    "path": "src/openpi/models/pi0.py",
    "content": "import logging\n\nimport einops\nimport flax.nnx as nnx\nimport flax.nnx.bridge as nnx_bridge\nimport jax\nimport jax.numpy as jnp\nfrom typing_extensions import override\n\nfrom openpi.models import model as _model\nfrom openpi.models import pi0_config\nimport openpi.models.gemma as _gemma\nimport openpi.models.siglip as _siglip\nfrom openpi.shared import array_typing as at\n\nlogger = logging.getLogger(\"openpi\")\n\n\ndef make_attn_mask(input_mask, mask_ar):\n    \"\"\"Adapted from big_vision.\n\n    Tokens can attend to valid inputs tokens which have a cumulative mask_ar\n    smaller or equal to theirs. This way `mask_ar` bool[?B, N] can be used to\n    setup several types of attention, for example:\n\n      [[1 1 1 1 1 1]]: pure causal attention.\n\n      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between\n          themselves and the last 3 tokens have a causal attention. The first\n          entry could also be a 1 without changing behaviour.\n\n      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. 
Tokens of a\n          block can attend all previous blocks and all tokens on the same block.\n\n    Args:\n      input_mask: bool[B, N] true if its part of the input, false if padding.\n      mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on\n        it and false where it shares the same attention mask as the previous token.\n    \"\"\"\n    mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)\n    cumsum = jnp.cumsum(mask_ar, axis=1)\n    attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]\n    valid_mask = input_mask[:, None, :] * input_mask[:, :, None]\n    return jnp.logical_and(attn_mask, valid_mask)\n\n\n@at.typecheck\ndef posemb_sincos(\n    pos: at.Real[at.Array, \" b\"], embedding_dim: int, min_period: float, max_period: float\n) -> at.Float[at.Array, \"b {embedding_dim}\"]:\n    \"\"\"Computes sine-cosine positional embedding vectors for scalar positions.\"\"\"\n    if embedding_dim % 2 != 0:\n        raise ValueError(f\"embedding_dim ({embedding_dim}) must be divisible by 2\")\n\n    fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)\n    period = min_period * (max_period / min_period) ** fraction\n    sinusoid_input = jnp.einsum(\n        \"i,j->ij\",\n        pos,\n        1.0 / period * 2 * jnp.pi,\n        precision=jax.lax.Precision.HIGHEST,\n    )\n    return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)\n\n\nclass Pi0(_model.BaseModel):\n    def __init__(self, config: pi0_config.Pi0Config, rngs: nnx.Rngs):\n        super().__init__(config.action_dim, config.action_horizon, config.max_token_len)\n        self.pi05 = config.pi05\n        paligemma_config = _gemma.get_config(config.paligemma_variant)\n        action_expert_config = _gemma.get_config(config.action_expert_variant)\n        # TODO: rewrite gemma in NNX. 
For now, use bridge.\n        llm = nnx_bridge.ToNNX(\n            _gemma.Module(\n                configs=[paligemma_config, action_expert_config],\n                embed_dtype=config.dtype,\n                adarms=config.pi05,\n            )\n        )\n        llm.lazy_init(rngs=rngs, method=\"init\", use_adarms=[False, True] if config.pi05 else [False, False])\n        img = nnx_bridge.ToNNX(\n            _siglip.Module(\n                num_classes=paligemma_config.width,\n                variant=\"So400m/14\",\n                pool_type=\"none\",\n                scan=True,\n                dtype_mm=config.dtype,\n            )\n        )\n        img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)\n        self.PaliGemma = nnx.Dict(llm=llm, img=img)\n        self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)\n        if config.pi05:\n            self.time_mlp_in = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)\n            self.time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)\n        else:\n            self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)\n            self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)\n            self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)\n        self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)\n\n        # This attribute gets automatically set by model.train() and model.eval().\n        self.deterministic = True\n\n    @at.typecheck\n    def embed_prefix(\n        self, obs: _model.Observation\n    ) -> tuple[at.Float[at.Array, \"b s emb\"], at.Bool[at.Array, \"b s\"], at.Bool[at.Array, \" s\"]]:\n        input_mask = []\n        ar_mask = []\n        tokens = []\n        # embed 
images\n        for name in obs.images:\n            image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)\n\n            tokens.append(image_tokens)\n            input_mask.append(\n                einops.repeat(\n                    obs.image_masks[name],\n                    \"b -> b s\",\n                    s=image_tokens.shape[1],\n                )\n            )\n            # image tokens attend to each other\n            ar_mask += [False] * image_tokens.shape[1]\n\n        # add language (aka tokenized inputs)\n        if obs.tokenized_prompt is not None:\n            tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method=\"embed\")\n            tokens.append(tokenized_inputs)\n            input_mask.append(obs.tokenized_prompt_mask)\n            # full attention between image and language inputs\n            ar_mask += [False] * tokenized_inputs.shape[1]\n        tokens = jnp.concatenate(tokens, axis=1)\n        input_mask = jnp.concatenate(input_mask, axis=1)\n        ar_mask = jnp.array(ar_mask)\n        return tokens, input_mask, ar_mask\n\n    @at.typecheck\n    def embed_suffix(\n        self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, \" b\"]\n    ) -> tuple[\n        at.Float[at.Array, \"b s emb\"],\n        at.Bool[at.Array, \"b s\"],\n        at.Bool[at.Array, \" s\"],\n        at.Float[at.Array, \"b emb\"] | None,\n    ]:\n        input_mask = []\n        ar_mask = []\n        tokens = []\n        if not self.pi05:\n            # add a single state token\n            state_token = self.state_proj(obs.state)[:, None, :]\n            tokens.append(state_token)\n            input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))\n            # image/language inputs do not attend to state or actions\n            ar_mask += [True]\n\n        action_tokens = self.action_in_proj(noisy_actions)\n        # embed timestep using sine-cosine positional encoding with 
sensitivity in the range [0, 1]\n        time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)\n        if self.pi05:\n            # time MLP (for adaRMS)\n            time_emb = self.time_mlp_in(time_emb)\n            time_emb = nnx.swish(time_emb)\n            time_emb = self.time_mlp_out(time_emb)\n            time_emb = nnx.swish(time_emb)\n            action_expert_tokens = action_tokens\n            adarms_cond = time_emb\n        else:\n            # mix timestep + action information using an MLP (no adaRMS)\n            time_tokens = einops.repeat(time_emb, \"b emb -> b s emb\", s=self.action_horizon)\n            action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)\n            action_time_tokens = self.action_time_mlp_in(action_time_tokens)\n            action_time_tokens = nnx.swish(action_time_tokens)\n            action_time_tokens = self.action_time_mlp_out(action_time_tokens)\n            action_expert_tokens = action_time_tokens\n            adarms_cond = None\n        tokens.append(action_expert_tokens)\n        input_mask.append(jnp.ones(action_expert_tokens.shape[:2], dtype=jnp.bool_))\n        # image/language/state inputs do not attend to action tokens\n        ar_mask += [True] + ([False] * (self.action_horizon - 1))\n        tokens = jnp.concatenate(tokens, axis=1)\n        input_mask = jnp.concatenate(input_mask, axis=1)\n        ar_mask = jnp.array(ar_mask)\n        return tokens, input_mask, ar_mask, adarms_cond\n\n    @override\n    def compute_loss(\n        self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False\n    ) -> at.Float[at.Array, \"*b ah\"]:\n        preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)\n        observation = _model.preprocess_observation(preprocess_rng, observation, train=train)\n\n        batch_shape = actions.shape[:-2]\n        noise = jax.random.normal(noise_rng, 
actions.shape)\n        time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001\n        time_expanded = time[..., None, None]\n        x_t = time_expanded * noise + (1 - time_expanded) * actions\n        u_t = noise - actions\n\n        # one big forward pass of prefix + suffix at once\n        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)\n        suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time)\n        input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)\n        ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)\n        attn_mask = make_attn_mask(input_mask, ar_mask)\n        positions = jnp.cumsum(input_mask, axis=1) - 1\n        (prefix_out, suffix_out), _ = self.PaliGemma.llm(\n            [prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions, adarms_cond=[None, adarms_cond]\n        )\n        v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])\n\n        return jnp.mean(jnp.square(v_t - u_t), axis=-1)\n\n    @override\n    def sample_actions(\n        self,\n        rng: at.KeyArrayLike,\n        observation: _model.Observation,\n        *,\n        num_steps: int | at.Int[at.Array, \"\"] = 10,\n        noise: at.Float[at.Array, \"b ah ad\"] | None = None,\n    ) -> _model.Actions:\n        observation = _model.preprocess_observation(None, observation, train=False)\n        # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target\n        # distribution. 
yes, this is the opposite of the pi0 paper, and I'm sorry.\n        dt = -1.0 / num_steps\n        batch_size = observation.state.shape[0]\n        if noise is None:\n            noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))\n\n        # first fill KV cache with a forward pass of the prefix\n        prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)\n        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)\n        positions = jnp.cumsum(prefix_mask, axis=1) - 1\n        _, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)\n\n        def step(carry):\n            x_t, time = carry\n            suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(\n                observation, x_t, jnp.broadcast_to(time, batch_size)\n            )\n            # `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to each\n            # other\n            suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)\n            # `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to the\n            # prefix tokens\n            prefix_attn_mask = einops.repeat(prefix_mask, \"b p -> b s p\", s=suffix_tokens.shape[1])\n            # `combined_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which\n            # generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values)\n            full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)\n            assert full_attn_mask.shape == (\n                batch_size,\n                suffix_tokens.shape[1],\n                prefix_tokens.shape[1] + suffix_tokens.shape[1],\n            )\n            # `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens\n       
     positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1\n\n            (prefix_out, suffix_out), _ = self.PaliGemma.llm(\n                [None, suffix_tokens],\n                mask=full_attn_mask,\n                positions=positions,\n                kv_cache=kv_cache,\n                adarms_cond=[None, adarms_cond],\n            )\n            assert prefix_out is None\n            v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])\n\n            return x_t + dt * v_t, time + dt\n\n        def cond(carry):\n            x_t, time = carry\n            # robust to floating-point error\n            return time >= -dt / 2\n\n        x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))\n        return x_0\n"
  },
  {
    "path": "src/openpi/models/pi0_config.py",
    "content": "import dataclasses\nfrom typing import TYPE_CHECKING\n\nimport flax.nnx as nnx\nimport jax\nimport jax.numpy as jnp\nfrom typing_extensions import override\n\nfrom openpi.models import model as _model\nimport openpi.models.gemma as _gemma\nfrom openpi.shared import array_typing as at\nimport openpi.shared.nnx_utils as nnx_utils\n\nif TYPE_CHECKING:\n    from openpi.models.pi0 import Pi0\n\n\n@dataclasses.dataclass(frozen=True)\nclass Pi0Config(_model.BaseModelConfig):\n    dtype: str = \"bfloat16\"\n    paligemma_variant: _gemma.Variant = \"gemma_2b\"\n    action_expert_variant: _gemma.Variant = \"gemma_300m\"\n\n    # Set the model specific defaults.\n    action_dim: int = 32\n    action_horizon: int = 50\n    max_token_len: int = None  # type: ignore\n    # Pi05 has two differences from Pi0:\n    # - the state input is part of the discrete language tokens rather than a continuous input that is part of the suffix\n    # - the action expert uses adaRMSNorm to inject the flow matching timestep\n    pi05: bool = False\n    # This config option is not used directly by the model, but it is read by the ModelTransformFactory.\n    discrete_state_input: bool = None  # type: ignore\n\n    pytorch_compile_mode: str | None = \"max-autotune\"\n\n    def __post_init__(self):\n        if self.max_token_len is None:\n            object.__setattr__(self, \"max_token_len\", 200 if self.pi05 else 48)\n        if self.discrete_state_input is None:\n            object.__setattr__(self, \"discrete_state_input\", self.pi05)\n        if self.pytorch_compile_mode is not None:\n            assert self.pytorch_compile_mode in [\n                \"default\",\n                \"reduce-overhead\",\n                \"max-autotune\",\n                \"max-autotune-no-cudagraphs\",\n            ]\n\n    @property\n    @override\n    def model_type(self) -> _model.ModelType:\n        if self.pi05:\n            return _model.ModelType.PI05\n        return _model.ModelType.PI0\n\n  
  @override\n    def create(self, rng: at.KeyArrayLike) -> \"Pi0\":\n        from openpi.models.pi0 import Pi0\n\n        return Pi0(self, rngs=nnx.Rngs(rng))\n\n    @override\n    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:\n        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)\n        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)\n\n        with at.disable_typechecking():\n            observation_spec = _model.Observation(\n                images={\n                    \"base_0_rgb\": image_spec,\n                    \"left_wrist_0_rgb\": image_spec,\n                    \"right_wrist_0_rgb\": image_spec,\n                },\n                image_masks={\n                    \"base_0_rgb\": image_mask_spec,\n                    \"left_wrist_0_rgb\": image_mask_spec,\n                    \"right_wrist_0_rgb\": image_mask_spec,\n                },\n                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),\n                tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),\n                tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),\n            )\n        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)\n\n        return observation_spec, action_spec\n\n    def get_freeze_filter(self) -> nnx.filterlib.Filter:\n        \"\"\"Returns the freeze filter based on the model config.\"\"\"\n        filters = []\n        has_lora = False\n        gemma_params_filter = nnx_utils.PathRegex(\".*llm.*\")\n        action_expert_params_filter = nnx_utils.PathRegex(\".*llm.*_1.*\")\n        if \"lora\" in self.paligemma_variant:\n            filters.append(\n                gemma_params_filter,\n            )\n            if \"lora\" not in self.action_expert_variant:\n                # If only freeze gemma params, 
exclude action expert params from freezing.\n                filters.append(\n                    nnx.Not(action_expert_params_filter),\n                )\n            has_lora = True\n        elif \"lora\" in self.action_expert_variant:\n            filters.append(\n                action_expert_params_filter,\n            )\n            has_lora = True\n\n        if has_lora:\n            # If any LoRA is used, exclude all LoRA params from freezing so they remain trainable.\n            filters.append(\n                nnx.Not(nnx_utils.PathRegex(\".*lora.*\")),\n            )\n        if not filters:\n            return nnx.Nothing\n        return nnx.All(*filters)\n"
  },
  {
    "path": "src/openpi/models/pi0_fast.py",
    "content": "import dataclasses\nimport logging\nfrom typing import Any\n\nimport einops\nimport flax.nnx as nnx\nimport flax.nnx.bridge as nnx_bridge\nimport jax\nimport jax.numpy as jnp\nfrom typing_extensions import override\n\nfrom openpi.models import model as _model\nimport openpi.models.gemma_fast as _gemma\nimport openpi.models.siglip as _siglip\nfrom openpi.shared import array_typing as at\nimport openpi.shared.nnx_utils as nnx_utils\n\nlogger = logging.getLogger(\"openpi\")\n\nPALIGEMMA_EOS_TOKEN = 1\n\n\ndef make_attn_mask(input_mask, mask_ar):\n    \"\"\"Adapted from big_vision.\n\n    Tokens can attend to valid inputs tokens which have a cumulative mask_ar\n    smaller or equal to theirs. This way `mask_ar` bool[?B, N] can be used to\n    setup several types of attention, for example:\n\n      [[1 1 1 1 1 1]]: pure causal attention.\n\n      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between\n          themselves and the last 3 tokens have a causal attention. The first\n          entry could also be a 1 without changing behaviour.\n\n      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. 
Tokens of a\n          block can attend all previous blocks and all tokens on the same block.\n\n    Args:\n      input_mask: bool[B, N] true if its part of the input, false if padding.\n      mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on\n        it and false where it shares the same attention mask as the previous token.\n    \"\"\"\n    mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)\n    cumsum = jnp.cumsum(mask_ar, axis=1)\n    attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]\n    valid_mask = input_mask[:, None, :] * input_mask[:, :, None]\n    return jnp.logical_and(attn_mask, valid_mask)\n\n\n@jax.vmap\ndef left_to_right_align(x, input_mask, attn_mask):\n    \"\"\"Converts input from left-align to right-aligned.\"\"\"\n    # Due to vmap, this is operating in a single example (not batch level).\n    assert x.ndim == 2\n    assert input_mask.ndim == 1\n    assert attn_mask.ndim == 2\n    assert x.shape[0] == input_mask.shape[0]\n    assert attn_mask.shape[0] == attn_mask.shape[1], attn_mask.shape\n    seqlen = jnp.max(input_mask * jnp.arange(input_mask.shape[0])) + 1\n    x = jnp.roll(x, -seqlen, axis=0)\n    input_mask = jnp.roll(input_mask, -seqlen, axis=0)\n    attn_mask = jnp.roll(attn_mask, -seqlen, axis=(0, 1))\n    return x, input_mask, attn_mask\n\n\ndef put_along_last_axis(arr, indices, values):\n    \"\"\"Like np.put_along_axis(..., axis=-1), since jax is missing it.\"\"\"\n    assert arr.ndim == indices.ndim == values.ndim, (arr.ndim, indices.ndim, values.ndim)\n    onehot = jax.nn.one_hot(indices, arr.shape[-1], dtype=values.dtype)\n    put_mask = jnp.einsum(\"...i,...in->...n\", jnp.ones(values.shape, jnp.int32), onehot)\n    put_values = jnp.einsum(\"...i,...in->...n\", values, onehot)\n    return jnp.where(put_mask, put_values, arr)\n\n\n@dataclasses.dataclass(frozen=True)\nclass Pi0FASTConfig(_model.BaseModelConfig):\n    dtype: str = \"bfloat16\"\n    paligemma_variant: _gemma.Variant = \"gemma_2b\"\n\n 
   # Set the model specific defaults.\n    action_dim: int = 32\n    action_horizon: int = 32\n    max_token_len: int = 250\n\n    # Tokenizer for the fast model.\n    fast_model_tokenizer: Any | None = None\n    # Keyword arguments for the fast model tokenizer.\n    fast_model_tokenizer_kwargs: dict[str, Any] | None = None\n\n    @property\n    @override\n    def model_type(self) -> _model.ModelType:\n        return _model.ModelType.PI0_FAST\n\n    @override\n    def create(self, rng: at.KeyArrayLike) -> \"Pi0FAST\":\n        return Pi0FAST(self, rngs=nnx.Rngs(rng))\n\n    @override\n    def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:\n        image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)\n        image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)\n\n        with at.disable_typechecking():\n            observation_spec = _model.Observation(\n                images={\n                    \"base_0_rgb\": image_spec,\n                    \"base_1_rgb\": image_spec,\n                    \"wrist_0_rgb\": image_spec,\n                },\n                image_masks={\n                    \"base_0_rgb\": image_mask_spec,\n                    \"base_1_rgb\": image_mask_spec,\n                    \"wrist_0_rgb\": image_mask_spec,\n                },\n                state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),\n                tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),\n                tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),\n                token_ar_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),\n                token_loss_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.bool_),\n            )\n        action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)\n\n        return 
observation_spec, action_spec\n\n    def get_freeze_filter(self) -> nnx.filterlib.Filter:\n        \"\"\"Returns the freeze filter based on the model config.\"\"\"\n        if \"lora\" in self.paligemma_variant:\n            return nnx.All(nnx_utils.PathRegex(\".*llm.*\"), nnx.Not(nnx_utils.PathRegex(\".*lora.*\")))\n        return nnx.Nothing\n\n\nclass Pi0FAST(_model.BaseModel):\n    def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):\n        super().__init__(config.action_dim, config.action_horizon, config.max_token_len)\n        paligemma_config = _gemma.get_config(config.paligemma_variant)\n        # TODO: rewrite gemma in NNX. For now, use bridge.\n        llm = nnx_bridge.ToNNX(\n            _gemma.Module(\n                **paligemma_config,\n                embed_dtype=config.dtype,\n                cache_dtype=config.dtype,\n            )\n        )\n        llm.lazy_init(rngs=rngs, method=\"init\")\n        img = nnx_bridge.ToNNX(\n            _siglip.Module(\n                num_classes=paligemma_config.width,\n                variant=\"So400m/14\",\n                pool_type=\"none\",\n                scan=True,\n                dtype_mm=config.dtype,\n            )\n        )\n        img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)\n        self.PaliGemma = nnx.Dict(llm=llm, img=img)\n\n    @at.typecheck\n    def embed_inputs(\n        self, obs: _model.Observation\n    ) -> tuple[at.Float[at.Array, \"b s emb\"], at.Bool[at.Array, \"b s\"], at.Int[at.Array, \"b s\"]]:\n        input_mask = []\n        ar_mask = []\n        token_embeddings = []\n        # embed images\n        for name in obs.images:\n            image_token_embeddings, _ = self.PaliGemma.img(obs.images[name], train=False)\n\n            token_embeddings.append(image_token_embeddings)\n            input_mask.append(\n                einops.repeat(\n                    obs.image_masks[name],\n                    \"b -> b s\",\n               
     s=image_token_embeddings.shape[1],\n                )\n            )\n            # image tokens attend to each other --> AR mask = 0\n            ar_mask.append(0 * input_mask[-1])\n\n        # add tokenized inputs\n        assert obs.tokenized_prompt is not None, \"Tokenized prompt is required\"\n        assert obs.tokenized_prompt_mask is not None, \"Tokenized prompt mask is required\"\n        assert obs.token_ar_mask is not None, \"Token auto-regressive mask is required\"\n        tokenized_inputs_embeddings = self.PaliGemma.llm(obs.tokenized_prompt, embed_only=True)\n        token_embeddings.append(tokenized_inputs_embeddings)\n        input_mask.append(obs.tokenized_prompt_mask)\n        ar_mask.append(obs.token_ar_mask)\n\n        # return embeddings, input mask, and ar mask\n        return (\n            jnp.concatenate(token_embeddings, axis=1),\n            jnp.concatenate(input_mask, axis=1),\n            jnp.concatenate(ar_mask, axis=1),\n        )\n\n    @override\n    def compute_loss(\n        self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False\n    ) -> at.Float[at.Array, \"*b ah\"]:\n        observation = _model.preprocess_observation(\n            rng, observation, train=train, image_keys=list(observation.images.keys())\n        )\n\n        # Compute inputs: one big forward pass of prefix + suffix at once\n        input_token_embeddings, input_mask, ar_mask = self.embed_inputs(observation)\n        attn_mask = make_attn_mask(input_mask, ar_mask)\n\n        # Compute one-hot targets: we predict *next* token, so shift the input tokens by one.\n        targets = jax.nn.one_hot(\n            observation.tokenized_prompt[:, 1:],\n            self.PaliGemma.llm.module.vocab_size,\n        )\n\n        # Each input predicts *next* token, so we don't input the last token.\n        pre_logits, _, _ = self.PaliGemma.llm(\n            embedded_prefix=input_token_embeddings[:, :-1],\n            
mask=attn_mask[:, :-1, :-1],\n            return_prelogits=True,\n        )\n\n        # Only decode logits for the target tokens to save memory\n        # (decoding matmul is large because it is a seq_len x vocab_size dense layer).\n        logits, _ = self.PaliGemma.llm(\n            pre_logits=pre_logits[:, -targets.shape[1] :],\n        )\n        logp = jax.nn.log_softmax(logits, axis=-1)\n\n        # Compute CE loss on token targets\n        assert observation.token_loss_mask is not None, \"Token loss mask is required\"\n        loss_mask = observation.token_loss_mask[:, 1:]\n        token_pplx = jnp.sum(targets * logp, axis=-1)\n        return -jnp.sum(token_pplx * loss_mask, axis=-1) / jnp.clip(jnp.sum(loss_mask, -1), 1)\n\n    @override\n    def sample_actions(\n        self,\n        rng: at.KeyArrayLike,\n        observation: _model.Observation,\n        *,\n        max_decoding_steps: int | at.Int[at.Array, \"\"] = 256,\n        temperature: float = 0.0,\n    ) -> _model.Actions:\n        # TODO: this is a hack to get the image keys.\n        observation = _model.preprocess_observation(\n            None, observation, train=False, image_keys=list(observation.images.keys())\n        )\n\n        # embed inputs\n        prefix_token_embeddings, prefix_mask, prefix_ar_mask = self.embed_inputs(observation)\n        prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)\n\n        # left to right align all input token sequences\n        prefix_token_embeddings, prefix_mask, prefix_attn_mask = left_to_right_align(\n            prefix_token_embeddings, prefix_mask, prefix_attn_mask\n        )\n        prefill_size = prefix_token_embeddings.shape[1]\n        prefill_len = jnp.sum(prefix_mask, axis=-1)\n        prefix_start = prefill_size - prefill_len\n\n        # first fill KV cache with a forward pass of the prefix\n        # pad attention mask to set the size of the KV cache (prefill_size + max_decoding_steps)\n        prefix_attn_mask = 
jnp.pad(prefix_attn_mask, ((0, 0), (0, 0), (0, max_decoding_steps)))\n        prefix_positions = jnp.cumsum(prefix_mask, axis=-1) - 1\n        prefix_logits, kv_cache, _ = self.PaliGemma.llm(\n            embedded_prefix=prefix_token_embeddings, mask=prefix_attn_mask, positions=prefix_positions, decode=True\n        )\n\n        # prepare decoding -- final logit decodes the first token\n        last_logit = prefix_logits[:, -1:]\n        output_tokens = jnp.zeros((last_logit.shape[0], max_decoding_steps))\n\n        def step(carry):\n            rng, last_logit, output_tokens, cache, _, step = carry\n\n            # Sample token from last logit\n            # Split RNG for this step\n            rng, rng_step = jax.random.split(rng)\n            token = jax.lax.cond(\n                temperature > 0.0,\n                lambda _: jax.random.categorical(rng_step, last_logit / temperature, axis=-1),\n                lambda _: jnp.argmax(last_logit, axis=-1),\n                operand=None,\n            )\n            output_tokens = put_along_last_axis(output_tokens, jnp.broadcast_to(step, (token.shape[0], 1)), token)\n\n            # Check for early stopping --> stop if all batch elements have EOS token\n            has_eos = jnp.any(token == PALIGEMMA_EOS_TOKEN, axis=-1)\n            all_eos = jnp.all(has_eos)\n\n            # Decode one step\n            token_embedding = self.PaliGemma.llm(token, embed_only=True)\n            positions = prefill_len[:, None] + step + 1\n            mask = jnp.logical_and(\n                jnp.arange(prefill_size + max_decoding_steps)[None, None, :] >= prefix_start[:, None, None],\n                jnp.arange(prefill_size + max_decoding_steps)[None, None, :]\n                < (jnp.broadcast_to(prefill_size + step + 1, (prefix_start.shape[0], 1, 1))),\n            )\n            last_logit, kv_cache, _ = self.PaliGemma.llm(\n                embedded_prefix=token_embedding, mask=mask, positions=positions, decode=True, kv_cache=cache\n 
           )\n\n            return rng, last_logit, output_tokens, kv_cache, all_eos, step + 1\n\n        def cond(carry):\n            _, _, _, _, all_eos, step = carry\n            return (~all_eos) & (step < max_decoding_steps)\n\n        # Use lax.while_loop so we can jit the full decoding loop.\n        _, _, output_tokens, _, _, _ = jax.lax.while_loop(\n            cond, step, (rng, last_logit, output_tokens, kv_cache, False, 0)\n        )\n        return output_tokens\n"
  },
  {
    "path": "src/openpi/models/pi0_test.py",
    "content": "import flax.nnx as nnx\nimport jax\n\nimport openpi.models.pi0_config as _pi0_config\n\n\ndef _get_frozen_state(config: _pi0_config.Pi0Config) -> nnx.State:\n    abstract_model = nnx.eval_shape(config.create, jax.random.key(0))\n\n    freeze_filter = config.get_freeze_filter()\n    return nnx.state(abstract_model, nnx.All(nnx.Param, freeze_filter)).flat_state()\n\n\ndef test_pi0_full_finetune():\n    config = _pi0_config.Pi0Config()\n    state = _get_frozen_state(config)\n    assert len(state) == 0\n\n\ndef test_pi0_gemma_lora():\n    config = _pi0_config.Pi0Config(paligemma_variant=\"gemma_2b_lora\")\n    state = _get_frozen_state(config)\n    assert len(state) == 9\n    assert all(\"lora\" not in p for p in state)\n    assert all(\"llm\" in p for p in state)\n    assert all(\"_1\" not in p for p in state)\n\n\ndef test_pi0_action_expert_lora():\n    config = _pi0_config.Pi0Config(action_expert_variant=\"gemma_300m_lora\")\n    state = _get_frozen_state(config)\n    # excluding embedder, rest of the params should be same as gemma_lora.\n    assert len(state) == 8\n    assert all(\"lora\" not in p for p in state)\n    assert all(\"llm\" in p for p in state)\n    # all frozen params should have _1 in their path since it's the action expert.\n    assert all(any(\"_1\" in p for p in path) for path in state)\n\n\ndef test_pi0_all_lora():\n    config = _pi0_config.Pi0Config(paligemma_variant=\"gemma_2b_lora\", action_expert_variant=\"gemma_300m_lora\")\n    state = _get_frozen_state(config)\n    # sum of gemma_lora and action_expert_lora's frozen params.\n    assert len(state) == 17\n    assert all(\"lora\" not in p for p in state)\n    assert all(\"llm\" in p for p in state)\n"
  },
  {
    "path": "src/openpi/models/siglip.py",
    "content": "# Copyright 2024 Big Vision Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"A refactored and simplified ViT adaptation for Pi, taken from big_vision.\"\"\"\n\nfrom collections.abc import Sequence\n\nimport flax.linen as nn\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nimport openpi.training.sharding as sharding\n\n\ndef posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32):\n    \"\"\"Follows the MoCo v3 logic.\"\"\"\n    y, x = jnp.mgrid[:h, :w]\n\n    assert width % 4 == 0, \"Width must be mult of 4 for sincos posemb\"\n    omega = jnp.arange(width // 4) / (width // 4 - 1)\n    omega = 1.0 / (temperature**omega)\n    y = jnp.einsum(\"m,d->md\", y.flatten(), omega)\n    x = jnp.einsum(\"m,d->md\", x.flatten(), omega)\n    pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)\n    return jnp.asarray(pe, dtype)[None, :, :]\n\n\ndef get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32):\n    if typ == \"learn\":\n        return self.param(\n            name,\n            nn.initializers.normal(stddev=1 / np.sqrt(width)),\n            (1, np.prod(seqshape), width),\n            dtype,\n        )\n    if typ == \"sincos2d\":\n        return posemb_sincos_2d(*seqshape, width, dtype=dtype)\n    raise ValueError(f\"Unknown posemb type: {typ}\")\n\n\nclass MlpBlock(nn.Module):\n    \"\"\"Transformer MLP / feed-forward block.\"\"\"\n\n    mlp_dim: int | None = None  # 
Defaults to 4x input dim\n    dropout: float = 0.0\n    dtype_mm: str = \"float32\"\n\n    @nn.compact\n    def __call__(self, x, deterministic=True):  # noqa: FBT002\n        \"\"\"Applies Transformer MlpBlock module.\"\"\"\n        inits = {\n            \"kernel_init\": nn.initializers.xavier_uniform(),\n            \"bias_init\": nn.initializers.normal(stddev=1e-6),\n        }\n\n        _, _, d = x.shape  # n,l,d\n        x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x)\n        x = nn.gelu(x)\n        x = nn.Dropout(rate=self.dropout)(x, deterministic)\n        return nn.Dense(d, dtype=self.dtype_mm, **inits)(x)\n\n\nclass Encoder1DBlock(nn.Module):\n    \"\"\"Single transformer encoder block (MHSA + MLP).\"\"\"\n\n    mlp_dim: int | None = None  # Defaults to 4x input dim\n    num_heads: int = 12\n    dropout: float = 0.0\n    dtype_mm: str = \"float32\"\n\n    @nn.compact\n    def __call__(self, x, deterministic=True):  # noqa: FBT002\n        out = {}\n        x = sharding.activation_sharding_constraint(x)\n        y = nn.LayerNorm(dtype=self.dtype_mm)(x)\n        y = out[\"sa\"] = nn.MultiHeadDotProductAttention(\n            num_heads=self.num_heads,\n            kernel_init=nn.initializers.xavier_uniform(),\n            deterministic=deterministic,\n            dtype=self.dtype_mm,\n        )(y, y)\n        y = sharding.activation_sharding_constraint(y)\n        y = nn.Dropout(rate=self.dropout)(y, deterministic)\n        x = out[\"+sa\"] = x + y\n\n        y = nn.LayerNorm(dtype=self.dtype_mm)(x)\n        y = out[\"mlp\"] = MlpBlock(\n            mlp_dim=self.mlp_dim,\n            dropout=self.dropout,\n            dtype_mm=self.dtype_mm,\n        )(y, deterministic)\n        y = sharding.activation_sharding_constraint(y)\n        y = nn.Dropout(rate=self.dropout)(y, deterministic)\n        x = out[\"+mlp\"] = x + y\n        x = sharding.activation_sharding_constraint(x)\n        return x, out\n\n\nclass Encoder(nn.Module):\n    
\"\"\"Transformer Model Encoder for sequence to sequence translation.\"\"\"\n\n    depth: int\n    mlp_dim: int | None = None  # Defaults to 4x input dim\n    num_heads: int = 12\n    dropout: float = 0.0\n    scan: bool = False\n    remat_policy: str = \"nothing_saveable\"\n    dtype_mm: str = \"float32\"\n\n    @nn.compact\n    def __call__(self, x, deterministic=True):  # noqa: FBT002\n        out = {}\n\n        if self.scan:\n            block = nn.remat(\n                Encoder1DBlock,\n                prevent_cse=False,\n                static_argnums=(2,),  # 0=self, 2=deterministic\n                policy=getattr(jax.checkpoint_policies, self.remat_policy, None),\n            )\n            x, scan_out = nn.scan(\n                block,\n                variable_axes={\"params\": 0},\n                split_rngs={\"params\": True, \"dropout\": True},\n                in_axes=nn.broadcast,\n                length=self.depth,\n            )(\n                name=\"encoderblock\",\n                dtype_mm=self.dtype_mm,\n                mlp_dim=self.mlp_dim,\n                num_heads=self.num_heads,\n                dropout=self.dropout,\n            )(x, deterministic)\n            for lyr in range(self.depth):\n                out[f\"block{lyr:02d}\"] = jax.tree.map(lambda o, lyr=lyr: o[lyr], scan_out)\n        else:\n            # Input Encoder\n            for lyr in range(self.depth):\n                block_cur = Encoder1DBlock(\n                    name=f\"encoderblock_{lyr}\",\n                    dtype_mm=self.dtype_mm,\n                    mlp_dim=self.mlp_dim,\n                    num_heads=self.num_heads,\n                    dropout=self.dropout,\n                )\n                x, out[f\"block{lyr:02d}\"] = block_cur(x, deterministic)\n            out[\"pre_ln\"] = x  # Alias for last block, but without the number in it.\n\n        return nn.LayerNorm(name=\"encoder_norm\", dtype=self.dtype_mm)(x), out\n\n\nclass MAPHead(nn.Module):\n    
\"\"\"Multihead Attention Pooling.\"\"\"\n\n    mlp_dim: int | None = None  # Defaults to 4x input dim\n    num_heads: int = 12\n    dtype_mm: str = \"float32\"\n\n    @nn.compact\n    def __call__(self, x):\n        n, _, d = x.shape  # n,l,d\n        probe = self.param(\"probe\", nn.initializers.xavier_uniform(), (1, 1, d), x.dtype)\n        probe = jnp.tile(probe, [n, 1, 1])\n\n        x = nn.MultiHeadDotProductAttention(\n            num_heads=self.num_heads,\n            dtype=self.dtype_mm,\n            kernel_init=nn.initializers.xavier_uniform(),\n        )(probe, x)\n\n        y = nn.LayerNorm(dtype=self.dtype_mm)(x)\n        x = x + MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype_mm)(y)\n        return x[:, 0]\n\n\nclass _Module(nn.Module):\n    \"\"\"ViT model.\"\"\"\n\n    num_classes: int | None = None\n    patch_size: Sequence[int] = (16, 16)\n    width: int = 768\n    depth: int = 12\n    mlp_dim: int | None = None  # Defaults to 4x input dim\n    num_heads: int = 12\n    posemb: str = \"learn\"  # Can also be \"sincos2d\"\n    rep_size: int | bool = False\n    dropout: float = 0.0\n    pool_type: str = \"gap\"  # Can also be \"map\" or \"tok\"\n    head_zeroinit: bool = True\n    scan: bool = False\n    # or \"dots_with_no_batch_dims_saveable\" for more speed (memory costly)\n    remat_policy: str = \"nothing_saveable\"\n    dtype_mm: str = \"float32\"\n\n    @nn.compact\n    def __call__(self, image, *, train=False):\n        out = {}\n\n        # Kevin edit: do patch extraction and posemb in float32,\n        # because I feel like it's a bit safer.\n        image = jnp.asarray(image, jnp.float32)\n\n        # Patch extraction\n        x = out[\"stem\"] = nn.Conv(\n            self.width,\n            self.patch_size,\n            strides=self.patch_size,\n            padding=\"VALID\",\n            name=\"embedding\",\n            dtype=jnp.float32,\n        )(image)\n\n        n, h, w, c = x.shape\n        x = jnp.reshape(x, [n, h * w, c])\n\n    
    # Add posemb before adding extra token.\n        x = out[\"with_posemb\"] = x + get_posemb(self, self.posemb, (h, w), c, \"pos_embedding\", jnp.float32)\n\n        if self.pool_type == \"tok\":\n            cls = self.param(\"cls\", nn.initializers.zeros, (1, 1, c), x.dtype)\n            x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1)\n\n        n, _, c = x.shape  # n,l,d\n        x = nn.Dropout(rate=self.dropout)(x, not train)\n\n        # Kevin edit: now cast back to dtype_mm (potentially half precision)\n        x = x.astype(self.dtype_mm)\n\n        x, out[\"encoder\"] = Encoder(\n            depth=self.depth,\n            mlp_dim=self.mlp_dim,\n            num_heads=self.num_heads,\n            dropout=self.dropout,\n            scan=self.scan,\n            remat_policy=self.remat_policy,\n            dtype_mm=self.dtype_mm,\n            name=\"Transformer\",\n        )(x, deterministic=not train)\n        encoded = out[\"encoded\"] = x\n\n        if self.pool_type == \"map\":\n            x = out[\"head_input\"] = MAPHead(\n                num_heads=self.num_heads,\n                mlp_dim=self.mlp_dim,\n                dtype=self.dtype_mm,\n            )(x)\n        elif self.pool_type == \"gap\":\n            x = out[\"head_input\"] = jnp.mean(x, axis=1)\n        elif self.pool_type == \"0\":\n            x = out[\"head_input\"] = x[:, 0]\n        elif self.pool_type == \"tok\":\n            x = out[\"head_input\"] = x[:, 0]\n            encoded = encoded[:, 1:]\n        elif self.pool_type == \"none\":\n            pass\n        else:\n            raise ValueError(f\"Unknown pool type: '{self.pool_type}'\")\n\n        x_2d = jnp.reshape(encoded, [n, h, w, -1])\n\n        if self.rep_size:\n            rep_size = self.width if self.rep_size is True else self.rep_size\n            hid = nn.Dense(rep_size, dtype=self.dtype_mm, name=\"pre_logits\")\n            # NOTE: In the past we did not include tanh in pre_logits.\n            # For 
few-shot, it should not matter much, as it whitens anyways.\n            x_2d = nn.tanh(hid(x_2d))\n            x = nn.tanh(hid(x))\n\n        out[\"pre_logits_2d\"] = x_2d\n        out[\"pre_logits\"] = x\n\n        if self.num_classes:\n            kw = {\"kernel_init\": nn.initializers.zeros} if self.head_zeroinit else {}\n            head = nn.Dense(self.num_classes, dtype=self.dtype_mm, name=\"head\", **kw)\n            x_2d = out[\"logits_2d\"] = head(x_2d)\n            x = out[\"logits\"] = head(x)\n\n        return x, out\n\n\ndef Module(num_classes=None, *, variant=None, **kw):  # pylint: disable=invalid-name  # noqa: N802\n    \"\"\"Factory function, because linen really don't like what I'm doing!\"\"\"\n    return _Module(num_classes, **{**decode_variant(variant), **kw})\n\n\ndef decode_variant(variant):\n    \"\"\"Converts a string like \"B\" or \"B/32\" into a params dict.\"\"\"\n    if variant is None:\n        return {}\n\n    v, patch = variant, {}\n    if \"/\" in variant:\n        v, patch = variant.split(\"/\")\n        patch = {\"patch_size\": (int(patch), int(patch))}\n\n    return {\n        # pylint:disable=line-too-long\n        # Reference: Table 2 of https://arxiv.org/abs/2106.04560.\n        \"width\": {\n            \"mu\": 32,\n            \"Ti\": 192,\n            \"S\": 384,\n            \"M\": 512,\n            \"B\": 768,\n            \"L\": 1024,\n            \"So400m\": 1152,\n            \"H\": 1280,\n            \"g\": 1408,\n            \"g-opt\": 1536,\n            \"G\": 1664,\n            \"G-opt\": 1536,\n            \"e\": 1792,\n        }[v],\n        \"depth\": {\n            \"mu\": 1,\n            \"Ti\": 12,\n            \"S\": 12,\n            \"M\": 12,\n            \"B\": 12,\n            \"L\": 24,\n            \"So400m\": 27,\n            \"H\": 32,\n            \"g\": 40,\n            \"g-opt\": 40,\n            \"G\": 48,\n            \"G-opt\": 48,\n            \"e\": 56,\n        }[v],\n        \"mlp_dim\": 
{\n            \"mu\": 128,\n            \"Ti\": 768,\n            \"S\": 1536,\n            \"M\": 2048,\n            \"B\": 3072,\n            \"L\": 4096,\n            \"So400m\": 4304,\n            \"H\": 5120,\n            \"g\": 6144,\n            \"g-opt\": 6144,\n            \"G\": 8192,\n            \"G-opt\": 8192,\n            \"e\": 15360,\n        }[v],\n        \"num_heads\": {\n            \"mu\": 2,\n            \"Ti\": 3,\n            \"S\": 6,\n            \"M\": 8,\n            \"B\": 12,\n            \"L\": 16,\n            \"So400m\": 16,\n            \"H\": 16,\n            \"g\": 16,\n            \"g-opt\": 16,\n            \"G\": 16,\n            \"G-opt\": 16,\n            \"e\": 16,\n        }[v],\n        # pylint:enable=line-too-long\n        **patch,\n    }\n"
  },
  {
    "path": "src/openpi/models/tokenizer.py",
    "content": "import logging\nimport os\n\nimport jax\nimport numpy as np\nimport orbax.checkpoint as ocp\nimport sentencepiece\nfrom transformers import AutoProcessor\n\nimport openpi.models.utils.fsq_tokenizer as fsq_tokenizer\nimport openpi.shared.download as download\n\n\nclass PaligemmaTokenizer:\n    def __init__(self, max_len: int = 48):\n        self._max_len = max_len\n\n        path = download.maybe_download(\"gs://big_vision/paligemma_tokenizer.model\", gs={\"token\": \"anon\"})\n        with path.open(\"rb\") as f:\n            self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())\n\n    def tokenize(self, prompt: str, state: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:\n        cleaned_text = prompt.strip().replace(\"_\", \" \").replace(\"\\n\", \" \")\n        if state is not None:\n            # This is the Pi05 format, where the state is part of the discrete language input.\n            discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1\n            state_str = \" \".join(map(str, discretized_state))\n            full_prompt = f\"Task: {cleaned_text}, State: {state_str};\\nAction: \"\n            tokens = self._tokenizer.encode(full_prompt, add_bos=True)\n        else:\n            # This is the Pi0 format, where the state is part of the continuous action expert input.\n            # tokenize \"\\n\" separately as the \"start of answer\" token\n            tokens = self._tokenizer.encode(cleaned_text, add_bos=True) + self._tokenizer.encode(\"\\n\")\n        tokens_len = len(tokens)\n        if tokens_len < self._max_len:\n            padding = [False] * (self._max_len - tokens_len)\n            mask = [True] * tokens_len + padding\n            tokens = tokens + padding\n        else:\n            if len(tokens) > self._max_len:\n                logging.warning(\n                    f\"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. 
\"\n                    \"Consider increasing the `max_token_len` in your model config if this happens frequently.\"\n                )\n            tokens = tokens[: self._max_len]\n            mask = [True] * self._max_len\n\n        return np.asarray(tokens), np.asarray(mask)\n\n\nclass FASTTokenizer:\n    def __init__(self, max_len: int = 256, fast_tokenizer_path: str = \"physical-intelligence/fast\"):\n        self._max_len = max_len\n\n        # Download base PaliGemma tokenizer\n        path = download.maybe_download(\"gs://big_vision/paligemma_tokenizer.model\", gs={\"token\": \"anon\"})\n        with path.open(\"rb\") as f:\n            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())\n\n        # Instantiate FAST tokenizer\n        self._fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)\n        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens\n\n    def tokenize(\n        self, prompt: str, state: np.ndarray, actions: np.ndarray | None\n    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:\n        cleaned_text = prompt.lower().strip().replace(\"_\", \" \")\n\n        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])\n        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1\n\n        # Convention: prefix includes prompt and string-representation of state, followed by ';'\n        state_str = \" \".join(map(str, discretized_state))\n        prefix = f\"Task: {cleaned_text}, State: {state_str};\\n\"\n        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)\n\n        if actions is not None:\n            # Tokenize actions with FAST tokenizer --> map to last tokens in PaliGemma vocab\n            action_tokens = self._fast_tokenizer(actions[None])[0]\n            action_tokens_in_pg = 
self._act_tokens_to_paligemma_tokens(action_tokens)\n\n            # Convention: postfix contains 'Action:' followed by FAST tokens, followed by '|'\n            postfix_tokens = (\n                self._paligemma_tokenizer.encode(\"Action: \")\n                + action_tokens_in_pg.tolist()\n                + self._paligemma_tokenizer.encode(\"|\", add_eos=True)\n            )\n        else:\n            postfix_tokens = []\n\n        # Create output token sequence & masks\n        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)\n        tokens = prefix_tokens + postfix_tokens\n        token_mask = [True] * len(tokens)\n        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)\n        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only\n\n        # Pad tokens to max length\n        tokens_len = len(tokens)\n        if tokens_len < self._max_len:\n            padding = [False] * (self._max_len - tokens_len)\n            tokens = tokens + padding\n            token_mask = token_mask + padding\n            ar_mask = ar_mask + padding\n            loss_mask = loss_mask + padding\n        else:\n            if len(tokens) > self._max_len:\n                logging.warning(\n                    f\"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. 
\"\n                    \"Consider increasing the `max_token_len` in your model config if this happens frequently.\"\n                )\n            tokens = tokens[: self._max_len]\n            token_mask = token_mask[: self._max_len]\n            ar_mask = ar_mask[: self._max_len]\n            loss_mask = loss_mask[: self._max_len]\n\n        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)\n\n    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:\n        # Decode predicted output tokens\n        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())\n\n        # Extract actions from FAST model outputs\n        if \"Action: \" not in decoded_tokens:\n            return np.zeros((action_horizon, action_dim), dtype=np.float32)\n\n        # Extract actions from decoded tokens\n        raw_action_tokens = np.array(\n            self._paligemma_tokenizer.encode(decoded_tokens.split(\"Action: \")[1].split(\"|\")[0].strip())\n        )\n        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)\n        return self._fast_tokenizer.decode(\n            [action_tokens.tolist()], time_horizon=action_horizon, action_dim=action_dim\n        )[0]\n\n    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:\n        if isinstance(tokens, list):\n            tokens = np.array(tokens)\n        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens\n\n\n###########################################################################\n## The tokenizers below are used for RoboArena baseline implementations. ##\n## They are *not* used for pi0-style models.                             
##\n###########################################################################\n\n\nclass BinningTokenizer:\n    \"\"\"\n    Standard RT-2 / OpenVLA style binning tokenizer.\n    \"\"\"\n\n    def __init__(self, max_len: int = 256, n_bins: int = 256):\n        self._max_len = max_len\n        self._n_bins = n_bins\n\n        # Download base PaliGemma tokenizer\n        path = download.maybe_download(\"gs://big_vision/paligemma_tokenizer.model\", gs={\"token\": \"anon\"})\n        with path.open(\"rb\") as f:\n            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())\n\n        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens\n\n    def tokenize(\n        self, prompt: str, state: np.ndarray, actions: np.ndarray | None\n    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:\n        \"\"\"Tokenize a prompt and state into a sequence of tokens.\n\n        Args:\n            prompt: The text prompt to tokenize.\n            state: The state array to discretize and tokenize.\n            actions: Must be None. 
Action encoding is not currently supported.\n\n        Returns:\n            A tuple of (tokens, token_mask, ar_mask, targets).\n\n        Raises:\n            NotImplementedError: If actions is not None.\n        \"\"\"\n        cleaned_text = prompt.lower().strip().replace(\"_\", \" \")\n\n        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])\n        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1\n\n        # Convention: prefix includes prompt and string-representation of state, followed by ';'\n        state_str = \" \".join(map(str, discretized_state))\n        prefix = f\"Task: {cleaned_text}, State: {state_str};\\n\"\n        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)\n\n        if actions is not None:\n            raise NotImplementedError(\"BinningTokenizer does not support encoding actions atm (only for inference use)\")\n        postfix_tokens = []\n\n        # Create output token sequence & masks\n        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)\n        tokens = prefix_tokens + postfix_tokens\n        token_mask = [True] * len(tokens)\n        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)\n        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only\n\n        # Pad tokens to max length\n        tokens_len = len(tokens)\n        if tokens_len < self._max_len:\n            padding = [False] * (self._max_len - tokens_len)\n            tokens = tokens + padding\n            token_mask = token_mask + padding\n            ar_mask = ar_mask + padding\n            loss_mask = loss_mask + padding\n        else:\n            if len(tokens) > self._max_len:\n                logging.warning(\n                    f\"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. 
\"\n                    \"Consider increasing the `max_token_len` in your model config if this happens frequently.\"\n                )\n            tokens = tokens[: self._max_len]\n            token_mask = token_mask[: self._max_len]\n            ar_mask = ar_mask[: self._max_len]\n            loss_mask = loss_mask[: self._max_len]\n\n        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)\n\n    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:\n        # Decode predicted output tokens\n        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())\n\n        # Extract actions from FAST model outputs\n        if \"Action: \" not in decoded_tokens:\n            return np.zeros((action_horizon, action_dim), dtype=np.float32)\n\n        # Extract actions from decoded tokens\n        raw_action_tokens = np.array(\n            self._paligemma_tokenizer.encode(decoded_tokens.split(\"Action: \")[1].split(\"|\")[0].strip())\n        )\n        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)\n        if len(action_tokens) < action_horizon * action_dim:\n            return np.zeros([action_horizon, action_dim], dtype=np.float32)\n        action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim])\n        return action_tokens / self._n_bins * 2 - 1\n\n    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:\n        if isinstance(tokens, list):\n            tokens = np.array(tokens)\n        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens\n\n\nclass FSQTokenizer:\n    \"\"\"\n    FSQ tokenizer from the FAST paper baselines.\n    \"\"\"\n\n    def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None):\n        self._max_len = max_len\n\n        assert fsq_tokenizer_path is not None, \"fsq_tokenizer_path must be 
provided\"\n        # Download tokenizer\n        path = download.maybe_download(fsq_tokenizer_path)\n        tok_path = os.path.join(path, os.listdir(path)[0])\n\n        # Split step from path\n        step = int(tok_path.split(\"/\")[-1])\n        base_path = tok_path.rsplit(\"/\", 1)[0]\n\n        mgr = ocp.CheckpointManager(\n            base_path,\n            item_handlers={\n                \"params\": ocp.StandardCheckpointHandler(),\n                \"opt_state\": ocp.StandardCheckpointHandler(),\n                \"config\": ocp.JsonCheckpointHandler(),\n            },\n            options=ocp.CheckpointManagerOptions(max_to_keep=1),\n        )\n\n        try:\n            restored = mgr.restore(\n                step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore())\n            )\n            config = restored[\"config\"]\n            self._params = restored[\"params\"]\n            self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config)\n        except Exception as e:\n            raise RuntimeError(\n                f\"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. 
Error: {e!s}\"\n            ) from e\n\n        # Compile tokenize and detokenize functions\n        self._tokenize_fn = jax.jit(\n            lambda params, x: self._fsq_tokenizer.apply({\"params\": params}, x, method=self._fsq_tokenizer.tokenize)\n        )\n        self._detokenize_fn = jax.jit(\n            lambda params, x: self._fsq_tokenizer.apply({\"params\": params}, x, method=self._fsq_tokenizer.detokenize)\n        )\n\n        # Download base PaliGemma tokenizer\n        path = download.maybe_download(\"gs://big_vision/paligemma_tokenizer.model\", gs={\"token\": \"anon\"})\n        with path.open(\"rb\") as f:\n            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())\n\n        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens\n\n    def tokenize(\n        self, prompt: str, state: np.ndarray, actions: np.ndarray | None\n    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:\n        cleaned_text = prompt.lower().strip().replace(\"_\", \" \")\n\n        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])\n        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1\n\n        # Convention: prefix includes prompt and string-representation of state, followed by ';'\n        state_str = \" \".join(map(str, discretized_state))\n        prefix = f\"Task: {cleaned_text}, State: {state_str};\\n\"\n        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)\n\n        if actions is not None:\n            raise NotImplementedError(\"FSQTokenizer does not support encoding actions atm (only for inference use)\")\n        postfix_tokens = []\n\n        # Create output token sequence & masks\n        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)\n        tokens = prefix_tokens + postfix_tokens\n        
token_mask = [True] * len(tokens)\n        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)\n        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only\n\n        # Pad tokens to max length\n        tokens_len = len(tokens)\n        if tokens_len < self._max_len:\n            padding = [False] * (self._max_len - tokens_len)\n            tokens = tokens + padding\n            token_mask = token_mask + padding\n            ar_mask = ar_mask + padding\n            loss_mask = loss_mask + padding\n        else:\n            if len(tokens) > self._max_len:\n                logging.warning(\n                    f\"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. \"\n                    \"Consider increasing the `max_token_len` in your model config if this happens frequently.\"\n                )\n            tokens = tokens[: self._max_len]\n            token_mask = token_mask[: self._max_len]\n            ar_mask = ar_mask[: self._max_len]\n            loss_mask = loss_mask[: self._max_len]\n\n        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)\n\n    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:\n        # Decode predicted output tokens\n        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())\n\n        # Extract actions from FAST model outputs\n        if \"Action: \" not in decoded_tokens:\n            return np.zeros((action_horizon, action_dim), dtype=np.float32)\n\n        # Extract actions from decoded tokens\n        raw_action_tokens = np.array(\n            self._paligemma_tokenizer.encode(decoded_tokens.split(\"Action: \")[1].split(\"|\")[0].strip())\n        )\n        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)\n        try:\n            # Move computation to CPU and compile on-demand\n            device = 
jax.devices(\"cpu\")[0]\n            with jax.default_device(device):\n                detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0]\n            return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim])\n        except Exception as e:\n            logging.warning(f\"Error decoding FSQ: {e}\")\n            return np.zeros((action_horizon, action_dim))\n\n    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:\n        if isinstance(tokens, list):\n            tokens = np.array(tokens)\n        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens\n"
  },
  {
    "path": "src/openpi/models/tokenizer_test.py",
    "content": "import numpy as np\n\nfrom openpi.models import tokenizer as _tokenizer\n\n\ndef test_tokenize():\n    tokenizer = _tokenizer.PaligemmaTokenizer(max_len=10)\n    tokens, masks = tokenizer.tokenize(\"Hello, world!\")\n\n    assert tokens.shape == (10,)\n    assert masks.shape == (10,)\n\n\ndef test_fast_tokenizer():\n    prompt = \"Hello, world!\"\n    state = np.random.rand(5).astype(np.float32)\n    action = np.random.rand(3, 2).astype(np.float32)\n    tokenizer = _tokenizer.FASTTokenizer(max_len=256)\n    tokens, token_masks, ar_masks, loss_masks = tokenizer.tokenize(prompt, state, action)\n\n    assert tokens.shape == (256,)\n    assert token_masks.shape == (256,)\n    assert ar_masks.shape == (256,)\n    assert loss_masks.shape == (256,)\n\n    act = tokenizer.extract_actions(tokens, 3, 2)\n    assert act.shape == (3, 2)\n"
  },
  {
    "path": "src/openpi/models/utils/fsq_tokenizer.py",
    "content": "import math\nfrom typing import Any, Literal\n\nimport chex\nfrom einops import einops\nfrom flax import linen as nn\nfrom flax.linen.module import Module\nfrom flax.linen.module import compact\nfrom flax.struct import dataclass\nfrom flax.typing import Array\nimport jax\nimport jax.numpy as jnp\n\n\nclass FsqCodebook(nn.Module):\n    input_dim: int\n    target_codebook_size: int\n    codebook_type: Literal[\"fsq\", \"lfq\"]\n\n    _bins_per_dim: tuple[int] | None = None\n\n    @property\n    def bins_per_dim(self) -> tuple[int]:\n        if self._bins_per_dim is not None:\n            return self._bins_per_dim\n\n        if self.codebook_type == \"fsq\":\n            return self._get_bins_fsq(self.target_codebook_size)\n        elif self.codebook_type == \"lfq\":  # noqa: RET505\n            return self._get_bins_lfq(self.target_codebook_size)\n        elif self.codebook_type == \"custom\":\n            return self._get_bins_custom(self.target_codebook_size)\n        else:\n            raise ValueError(f\"Codebook type {self.codebook_type} not supported.\")\n\n    @property\n    def place_values(self) -> jnp.ndarray:\n        place_values = [1]\n        for b in self.bins_per_dim[:-1]:\n            place_values.append(place_values[-1] * b)\n        return jnp.array(place_values)\n\n    @staticmethod\n    def _get_bins_fsq(target_codebook_size: int) -> tuple[int]:\n        \"\"\"\n        Get bins per dimension based on codebook size, from the original FSQ paper.\n        \"\"\"\n        if target_codebook_size == 2**8:\n            return (8, 6, 5)\n        elif target_codebook_size == 2**10:  # noqa: RET505\n            return (8, 5, 5, 5)\n        elif target_codebook_size == 2**12:\n            return (7, 5, 5, 5, 5)\n        elif target_codebook_size == 2**14:\n            return (8, 8, 8, 6, 5)\n        elif target_codebook_size == 2**16:\n            return (8, 8, 8, 5, 5, 5)\n        else:\n            raise ValueError(f\"Codebook size 
{target_codebook_size} not supported.\")\n\n    @staticmethod\n    def _get_bins_custom(target_codebook_size: int) -> tuple[int]:\n        if target_codebook_size == 2**8:\n            return (16, 16)\n        elif target_codebook_size == 2**10:  # noqa: RET505\n            return (32, 32)\n        elif target_codebook_size == 2**12:\n            return (64, 64)\n        elif target_codebook_size == 2**14:\n            return (128, 128)\n        elif target_codebook_size == 2**16:\n            return (256, 256)\n        return None\n\n    @staticmethod\n    def _get_bins_lfq(target_codebook_size: int) -> tuple[int]:\n        \"\"\"\n        Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension)\n        \"\"\"\n        assert target_codebook_size & (target_codebook_size - 1) == 0, \"Codebook size should be a power of two for LFQ\"\n\n        return (2,) * int(math.log2(target_codebook_size))\n\n    def setup(self):\n        self.proj_down = nn.Dense(len(self.bins_per_dim))\n        self.proj_up = nn.Dense(self.input_dim)\n\n    def __call__(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:\n        tokens, z = self.encode(inputs)\n        output = self.decode(tokens, z_grad=z)\n        return tokens, output\n\n    def encode(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:\n        bases = jnp.array(self.bins_per_dim)\n\n        x = self.proj_down(inputs)\n        z = jnp.tanh(x)\n\n        # Quantize\n        digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32)\n        tokens = self.undigitize(digits)\n\n        return tokens, z\n\n    def decode(self, tokens: jnp.ndarray, z_grad: jax.Array | None = None) -> jnp.ndarray:\n        bases = jnp.array(self.bins_per_dim)\n        digits = self.digitize(tokens)\n\n        z_q = digits / (bases - 1) * 2 - 1\n\n        if z_grad is not None:\n            chex.assert_equal_shape([z_q, z_grad])\n            z_q = jax.lax.stop_gradient(z_q - 
z_grad) + z_grad\n\n        return self.proj_up(z_q)\n\n    def undigitize(self, digits: jnp.ndarray) -> jnp.ndarray:\n        return jnp.sum(digits * jnp.array(self.place_values), axis=-1)\n\n    def digitize(self, tokens: jnp.ndarray) -> jnp.ndarray:\n        return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim)\n\n    @property\n    def vocab_size(self) -> int:\n        return math.prod(self.bins_per_dim)\n\n\nclass ResNetDownBlock(nn.Module):\n    stride: int = 1\n    n_filters: int = 64\n    dropout_rate: float = 0.0\n    group_size: int = 32\n\n    @nn.compact\n    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:\n        skip = x\n\n        if self.stride > 1 or x.shape[-1] != self.n_filters:\n            skip = nn.Conv(self.n_filters, (self.stride,), (self.stride,), \"SAME\")(skip)\n\n        x = nn.Conv(self.n_filters, (3,), (self.stride,), \"SAME\")(x)\n        x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)\n        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)\n        x = nn.relu(x)\n        x = nn.Conv(self.n_filters, (3,), (1,), \"SAME\")(x)\n\n        return skip + x\n\n\nclass ResNetUpBlock(nn.Module):\n    stride: int = 1\n    n_filters: int = 64\n    dropout_rate: float = 0.0\n    group_size: int = 32\n\n    @nn.compact\n    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:\n        skip = x\n\n        if self.stride > 1:\n            skip = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), \"SAME\")(skip)\n\n        x = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), \"SAME\")(x)\n        x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)\n        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)\n        x = nn.relu(x)\n        x = nn.ConvTranspose(self.n_filters, (3,), (1,), \"SAME\")(x)\n\n        return skip + x\n\n\n@dataclass\nclass LfqCodebookOutput:\n    tokens: 
jnp.ndarray\n    z: jnp.ndarray\n    z_q: jnp.ndarray\n    token_log_probs: jnp.ndarray\n    commit_loss: jnp.ndarray\n\n\nclass LookupFreeQuantization(nn.Module):\n    num_dims: int\n    latent_dim: int\n\n    def setup(self):\n        self.codebook = jnp.array([-1, 1])\n        self.activation = nn.tanh\n\n        self.project_down = nn.Dense(self.num_dims)\n        self.project_up = nn.Dense(self.latent_dim)\n\n    def encode(self, z: jnp.ndarray) -> jnp.ndarray:\n        z = self.project_down(z)\n        token_squared_distances = jnp.square(z[..., None] - self.codebook)\n        token_bits = jnp.argmin(token_squared_distances, axis=-1)\n        return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)\n\n    def decode(self, tokens: jnp.ndarray) -> jnp.ndarray:\n        token_bits = (tokens[..., None] & (2 ** jnp.arange(self.num_dims))).astype(jnp.int32)\n        return self.project_up(self.codebook[token_bits])\n\n    def loss(self, x: jnp.ndarray) -> LfqCodebookOutput:\n        z = self.project_down(x)\n        z = self.activation(z)\n\n        token_squared_distances = jnp.square(z[..., None] - self.codebook)\n        tokens = jnp.argmin(token_squared_distances, axis=-1)\n\n        token_bit_log_probs = -token_squared_distances\n        # Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs\n        token_bit_expansions = jnp.bitwise_and(\n            jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]\n        ).astype(jnp.int32)\n        token_log_probs = (\n            token_bit_log_probs[..., 0] @ (1 - token_bit_expansions)\n            + token_bit_log_probs[..., 1] @ token_bit_expansions\n        )  # (batch_size, num_tokens, 2 ** num_dims)\n        token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1))\n        chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))\n\n        z_q = self.codebook[tokens]\n        commit_loss = 
jnp.square(z - z_q).mean()\n        z_q = jax.lax.stop_gradient(z_q - z) + z\n\n        z_q = self.project_up(z_q)\n        z = self.project_up(z)\n\n        tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1)\n        return LfqCodebookOutput(\n            tokens=tokens,\n            z=z,\n            z_q=z_q,\n            token_log_probs=jnp.zeros(()),\n            commit_loss=commit_loss,\n        )\n\n\ndef make_block_causal_attention_matrix(q: jnp.ndarray, k: jnp.ndarray, bs_q: int, bs_k: int) -> jnp.ndarray:\n    return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q))\n\n\nclass GeGLU(Module):\n    \"\"\"Gated Linear Unit with GELU (GeGLU) activation function.\n    GeGLU is a Flax layer that combines a linear transformation with a GELU\n    activation function in a gating mechanism. It is often used in Transformer models\n    to provide non-linear capabilities while preserving a strong linear component.\n\n    Attributes:\n        features: the number of output features (default: None).\n    \"\"\"\n\n    output_dim: int = -1\n\n    @compact\n    def __call__(self, inputs: Array) -> Array:\n        \"\"\"Applies the GeGLU activation to the inputs.\n        Args:\n            inputs: the nd-array to apply the GeGLU activation function to.\n        Returns:\n            The transformed input.\n        \"\"\"\n        output_dim = inputs.shape[-1] if self.output_dim == -1 else self.output_dim\n\n        x = nn.Dense(output_dim * 2)(inputs)\n        x, gate = x[..., :output_dim], x[..., output_dim:]\n        return x * nn.gelu(gate)\n\n\nclass CrossAttentionLayer(nn.Module):\n    dropout_rate: float = 0.0\n    num_heads: int = None\n    causal: bool = False\n    mlp_ratio: float = 4.0\n\n    @nn.compact\n    def __call__(\n        self,\n        x: jnp.ndarray,\n        y: jnp.ndarray,\n        *,\n        mask_self: jnp.ndarray | None = None,\n        mask_cross: jnp.ndarray | None = 
None,\n        train: bool = True,\n    ) -> jnp.ndarray:\n        d_embed = x.shape[-1]\n        seq_len_q = x.shape[-2]\n        seq_len_k = y.shape[-2]\n\n        if self.causal:\n            # One block size will be 1\n            bs_q = max(seq_len_q // seq_len_k, 1)\n            bs_k = max(seq_len_k // seq_len_q, 1)\n\n            mask_self = nn.make_causal_mask(x[..., 0])\n            mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k)\n\n        # Self-attention block\n        skip = x\n        x = nn.LayerNorm()(x)\n        x = nn.MultiHeadDotProductAttention(\n            num_heads=self.num_heads or d_embed // 64,\n            dropout_rate=self.dropout_rate,\n            deterministic=not train,\n        )(x, x, x, mask=mask_self)\n        x = skip + x\n\n        # Cross-attention block\n        skip = x\n        x = nn.LayerNorm()(x)\n        x = nn.MultiHeadDotProductAttention(\n            num_heads=self.num_heads or d_embed // 64,\n            dropout_rate=self.dropout_rate,\n            deterministic=not train,\n        )(x, y, y, mask=mask_cross)\n        x = skip + x\n\n        # MLP block\n        skip = x\n        x = nn.LayerNorm()(x)\n        x = nn.Dense(int(d_embed * self.mlp_ratio))(x)\n        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)\n        x = GeGLU()(x)\n        x = nn.Dense(d_embed)(x)\n        return skip + x\n\n\ndef sinusoidal_pe_init(_, shape: tuple[int, int]) -> jnp.ndarray:\n    seq_len, d_embed = shape\n\n    position = jnp.arange(0, seq_len, 1)\n    div_term = jnp.exp(jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed))\n    return jnp.concatenate(\n        [\n            jnp.sin(position[:, jnp.newaxis] * div_term),\n            jnp.cos(position[:, jnp.newaxis] * div_term),\n        ],\n        axis=-1,\n    )\n\n\nclass TokenizerEncoderDecoder(nn.Module):\n    num_tokens: int\n    num_cross_tokens: int\n    num_layers: int\n    causal: bool\n\n    mlp_ratio: float = 
4.0\n    use_state_conditioning: bool = False\n\n    @nn.compact\n    def __call__(\n        self,\n        y: jnp.ndarray,\n        *,\n        train: bool = True,\n        state_conditioning: jnp.ndarray | None = None,\n        mask: jnp.ndarray | None = None,\n    ) -> jnp.ndarray:\n        x = self.param(\"q_embed\", sinusoidal_pe_init, (self.num_tokens, y.shape[-1]))\n        x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:])\n\n        if mask is not None:\n            # mask is (batch_dims..., num_cross_tokens)\n            chex.assert_equal_shape([y[..., 0], mask])\n            attn_mask = einops.repeat(mask, \"... kv -> ... 1 q kv\", q=self.num_tokens)\n        else:\n            attn_mask = jnp.ones((*y.shape[:-2], 1, self.num_tokens, self.num_cross_tokens))\n\n        if self.use_state_conditioning:\n            assert state_conditioning is not None, \"State conditioning is required for this model.\"\n            state_embed = nn.Dense(y.shape[-1], name=\"state_proj\")(state_conditioning)[..., None, :]\n            y = jnp.concatenate([y, state_embed], axis=-2)\n            attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1)\n\n        y = y + self.param(\"y_pos_enc\", sinusoidal_pe_init, y.shape[-2:])\n\n        for _ in range(self.num_layers):\n            x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)(\n                x, y, train=train, mask_self=None, mask_cross=attn_mask\n            )\n\n        return x\n\n\nclass FsqAttentionTokenizer(nn.Module):\n    embed_dim: int\n    data_dim: int\n    data_horizon: int\n    num_tokens: int\n    num_layers: int\n    target_codebook_size: int\n    causal: bool = False\n    mlp_ratio: float = 2.0\n\n    bound: float | None = None\n\n    use_state_conditioning: bool = False\n\n    @property\n    def vocab_size(self) -> int:\n        return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size))  # noqa: SLF001\n\n    def setup(self):\n     
   self.proj = nn.Dense(self.embed_dim)\n        self.encoder = TokenizerEncoderDecoder(\n            num_tokens=self.num_tokens,\n            num_cross_tokens=self.data_horizon,\n            num_layers=self.num_layers,\n            causal=self.causal,\n            use_state_conditioning=self.use_state_conditioning,\n            mlp_ratio=self.mlp_ratio,\n        )\n        self.codebook = FsqCodebook(\n            input_dim=self.embed_dim,\n            target_codebook_size=self.target_codebook_size,\n            codebook_type=\"custom\",\n        )\n        self.decoder = TokenizerEncoderDecoder(\n            num_tokens=self.data_horizon,\n            num_cross_tokens=self.num_tokens,\n            num_layers=self.num_layers,\n            causal=self.causal,\n            use_state_conditioning=self.use_state_conditioning,\n            mlp_ratio=self.mlp_ratio,\n        )\n\n        self.proj_mean = nn.Dense(self.data_dim)\n        self.out_scale = self.param(\"out_scale\", lambda _: jnp.full((), 1.0))\n\n    def tokenize(\n        self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = False\n    ) -> tuple[jnp.ndarray, jnp.ndarray]:\n        if self.bound is not None:\n            action = jnp.clip(action, -self.bound, self.bound)\n\n        x = self.proj(action)\n        x = self.encoder(x, train=train, state_conditioning=obs)\n\n        return self.codebook.encode(x)\n\n    def detokenize(self, tokens: jnp.ndarray, *, obs: jnp.ndarray | None = None) -> jnp.ndarray:\n        x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)\n        mean = self.proj_mean(x)\n        return mean * self.out_scale\n\n    def loss(\n        self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = True\n    ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:\n        # Encode\n        x = self.proj(action)\n        z = self.encoder(x, train=train, state_conditioning=obs)\n\n        # Quantize\n        tokens, z = 
self.codebook(z)\n\n        # Decode\n        x = self.decoder(z, train=train, state_conditioning=obs)\n        mean = self.proj_mean(x) * self.out_scale\n\n        mse = jnp.mean(jnp.square(action - mean))\n        mae = jnp.mean(jnp.abs(action - mean))\n\n        return mse, {\n            \"mse\": mse,\n            \"mae\": mae,\n        }\n\n    def __call__(self, *args: Any, **kwargs: Any) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:\n        \"\"\"\n        Dummy for .init\n        \"\"\"\n        return self.loss(*args, **kwargs)\n"
  },
  {
    "path": "src/openpi/models/vit.py",
    "content": "# Copyright 2024 Google LLC.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"ViT implementation adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py.\"\"\"\n\nfrom collections.abc import Callable\nfrom typing import Any\n\nimport flax.linen as nn\nimport jax\nimport jax.numpy as jnp\n\nfrom openpi.models import resnet as models_resnet\n\nArray = Any\nPRNGKey = Any\nShape = tuple[int]\nDtype = Any\n\n\nclass IdentityLayer(nn.Module):\n    \"\"\"Identity layer, convenient for giving a name to an array.\"\"\"\n\n    @nn.compact\n    def __call__(self, x):\n        return x\n\n\nclass AddPositionEmbs(nn.Module):\n    \"\"\"Adds learned positional embeddings to the inputs.\n\n    Attributes:\n      posemb_init: positional embedding initializer.\n    \"\"\"\n\n    posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]\n    param_dtype: Dtype = jnp.float32\n\n    @nn.compact\n    def __call__(self, inputs):\n        \"\"\"Applies the AddPositionEmbs module.\n\n        Args:\n          inputs: Inputs to the layer.\n\n        Returns:\n          Output tensor with shape `(bs, timesteps, in_dim)`.\n        \"\"\"\n        # inputs.shape is (batch_size, seq_len, emb_dim).\n        assert inputs.ndim == 3, f\"Number of dimensions should be 3, but it is: {inputs.ndim}\"\n        pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])\n        pe = self.param(\"pos_embedding\", self.posemb_init, pos_emb_shape, 
self.param_dtype)\n        return inputs + pe\n\n\nclass MlpBlock(nn.Module):\n    \"\"\"Transformer MLP / feed-forward block.\"\"\"\n\n    mlp_dim: int\n    dtype: Dtype = jnp.float32\n    param_dtype: Dtype = jnp.float32\n    out_dim: int | None = None\n    dropout_rate: float = 0.1\n    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform()\n    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6)\n\n    @nn.compact\n    def __call__(self, inputs, *, deterministic):\n        \"\"\"Applies Transformer MlpBlock module.\"\"\"\n        actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim\n        x = nn.Dense(\n            features=self.mlp_dim,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n            kernel_init=self.kernel_init,\n            bias_init=self.bias_init,\n        )(  # pytype: disable=wrong-arg-types\n            inputs\n        )\n        x = nn.gelu(x)\n        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)\n        output = nn.Dense(\n            features=actual_out_dim,\n            dtype=self.dtype,\n            param_dtype=self.param_dtype,\n            kernel_init=self.kernel_init,\n            bias_init=self.bias_init,\n        )(  # pytype: disable=wrong-arg-types\n            x\n        )\n        return nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic)\n\n\nclass Encoder1DBlock(nn.Module):\n    \"\"\"Transformer encoder layer.\n\n    Attributes:\n      inputs: input data.\n      mlp_dim: dimension of the mlp on top of attention block.\n      dtype: the dtype of the computation (default: float32).\n      dropout_rate: dropout rate.\n      attention_dropout_rate: dropout for attention heads.\n      deterministic: bool, deterministic or not (to apply dropout).\n      num_heads: Number of heads in nn.MultiHeadDotProductAttention\n    \"\"\"\n\n    mlp_dim: int\n    num_heads: 
int\n    dtype: Dtype = jnp.float32\n    dropout_rate: float = 0.1\n    attention_dropout_rate: float = 0.1\n\n    @nn.compact\n    def __call__(self, inputs, deterministic):\n        \"\"\"Applies Encoder1DBlock module.\n\n        Args:\n          inputs: Inputs to the layer.\n          deterministic: Dropout will not be applied when set to true.\n\n        Returns:\n          output after transformer encoder block.\n        \"\"\"\n\n        # Attention block.\n        assert inputs.ndim == 3, f\"Expected (batch, seq, hidden) got {inputs.shape}\"\n        x = nn.LayerNorm(dtype=self.dtype)(inputs)\n        x = nn.MultiHeadDotProductAttention(\n            dtype=self.dtype,\n            kernel_init=nn.initializers.xavier_uniform(),\n            broadcast_dropout=False,\n            deterministic=deterministic,\n            dropout_rate=self.attention_dropout_rate,\n            num_heads=self.num_heads,\n            # why isn't this true by default???\n            force_fp32_for_softmax=True,\n        )(x, x)\n        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)\n        x = x + inputs\n\n        # MLP block.\n        y = nn.LayerNorm(dtype=self.dtype)(x)\n        y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)(\n            y, deterministic=deterministic\n        )\n\n        return x + y, None\n\n\nclass Encoder(nn.Module):\n    \"\"\"Transformer Model Encoder for sequence to sequence translation.\n\n    Attributes:\n      num_layers: number of layers\n      mlp_dim: dimension of the mlp on top of attention block\n      num_heads: Number of heads in nn.MultiHeadDotProductAttention\n      dropout_rate: dropout rate.\n      attention_dropout_rate: dropout rate in self attention.\n    \"\"\"\n\n    dtype: jax.typing.DTypeLike\n    num_layers: int\n    mlp_dim: int\n    num_heads: int\n    dropout_rate: float = 0.1\n    attention_dropout_rate: float = 0.1\n    add_position_embedding: bool = True\n\n    
@nn.compact\n    def __call__(self, x, *, train):\n        \"\"\"Applies Transformer model on the inputs.\n\n        Args:\n          x: Inputs to the layer.\n          train: Set to `True` when training.\n\n        Returns:\n          output of a transformer encoder.\n        \"\"\"\n        assert x.ndim == 3  # (batch, len, emb)\n\n        if self.add_position_embedding:\n            x = AddPositionEmbs(\n                posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.\n                name=\"posembed_input\",\n            )(x)\n            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)\n\n        x = x.astype(self.dtype)\n        # Input Encoder\n        block = nn.remat(Encoder1DBlock, prevent_cse=False, static_argnums=(2,))\n        x, _ = nn.scan(\n            block,\n            variable_axes={\"params\": 0},\n            split_rngs={\"params\": True, \"dropout\": True},\n            in_axes=nn.broadcast,\n            length=self.num_layers,\n        )(\n            name=\"encoderblock\",\n            mlp_dim=self.mlp_dim,\n            dropout_rate=self.dropout_rate,\n            attention_dropout_rate=self.attention_dropout_rate,\n            dtype=self.dtype,\n            num_heads=self.num_heads,\n        )(x, not train)\n        return nn.LayerNorm(name=\"encoder_norm\", dtype=self.dtype)(x)\n\n\nclass VisionTransformer(nn.Module):\n    \"\"\"VisionTransformer.\"\"\"\n\n    dtype: jax.typing.DTypeLike\n    num_classes: int\n    patches: Any\n    transformer: Any\n    hidden_size: int\n    resnet: Any | None = None\n    representation_size: int | None = None\n    classifier: str = \"token\"\n    head_bias_init: float = 0.0\n    encoder: type[nn.Module] = Encoder\n    model_name: str | None = None\n\n    @nn.compact\n    def __call__(self, inputs, *, train):\n        x = inputs\n        # (Possibly partial) ResNet root.\n        if self.resnet is not None:\n            width = int(64 * self.resnet.width_factor)\n\n        
    # Root block.\n            x = models_resnet.StdConv(\n                features=width, kernel_size=(7, 7), strides=(2, 2), use_bias=False, name=\"conv_root\"\n            )(x)\n            x = nn.GroupNorm(name=\"gn_root\")(x)\n            x = nn.relu(x)\n            x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding=\"SAME\")\n\n            # ResNet stages.\n            if self.resnet.num_layers:\n                x = models_resnet.ResNetStage(\n                    block_size=self.resnet.num_layers[0], nout=width, first_stride=(1, 1), name=\"block1\"\n                )(x)\n                for i, block_size in enumerate(self.resnet.num_layers[1:], 1):\n                    x = models_resnet.ResNetStage(\n                        block_size=block_size, nout=width * 2**i, first_stride=(2, 2), name=f\"block{i + 1}\"\n                    )(x)\n\n        n, h, w, c = x.shape\n\n        # We can merge s2d+emb into a single conv; it's the same.\n        x = nn.Conv(\n            features=self.hidden_size,\n            kernel_size=self.patches.size,\n            strides=self.patches.size,\n            padding=\"VALID\",\n            name=\"embedding\",\n        )(x)\n\n        # Here, x is a grid of embeddings.\n\n        # (Possibly partial) Transformer.\n        if self.transformer is not None:\n            n, h, w, c = x.shape\n            x = jnp.reshape(x, [n, h * w, c])\n\n            # If we want to add a class token, add it here.\n            if self.classifier in [\"token\", \"token_unpooled\"]:\n                cls = self.param(\"cls\", nn.initializers.zeros, (1, 1, c))\n                cls = jnp.tile(cls, [n, 1, 1])\n                x = jnp.concatenate([cls, x], axis=1)\n\n            x = self.encoder(name=\"Transformer\", **self.transformer, dtype=self.dtype)(x, train=train)\n\n        if self.classifier == \"token\":\n            x = x[:, 0]\n        elif self.classifier == \"gap\":\n            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # 
(1,) or (1,2)\n        elif self.classifier in [\"unpooled\", \"token_unpooled\"]:\n            pass\n        else:\n            raise ValueError(f\"Invalid classifier={self.classifier}\")\n\n        if self.representation_size is not None:\n            x = nn.Dense(features=self.representation_size, name=\"pre_logits\")(x)\n            x = nn.tanh(x)\n        else:\n            x = IdentityLayer(name=\"pre_logits\")(x)\n\n        if self.num_classes:\n            x = nn.Dense(\n                features=self.num_classes,\n                name=\"head\",\n                kernel_init=nn.initializers.zeros,\n                bias_init=nn.initializers.constant(self.head_bias_init),\n            )(x)\n        return x\n"
  },
  {
    "path": "src/openpi/models_pytorch/gemma_pytorch.py",
    "content": "from typing import Literal\n\nimport pytest\nimport torch\nfrom torch import nn\nfrom transformers import GemmaForCausalLM\nfrom transformers import PaliGemmaForConditionalGeneration\nfrom transformers.models.auto import CONFIG_MAPPING\nfrom transformers.models.gemma import modeling_gemma\n\n\nclass PaliGemmaWithExpertModel(nn.Module):\n    def __init__(\n        self,\n        vlm_config,\n        action_expert_config,\n        use_adarms=None,\n        precision: Literal[\"bfloat16\", \"float32\"] = \"bfloat16\",\n    ):\n        if use_adarms is None:\n            use_adarms = [False, False]\n        super().__init__()\n\n        vlm_config_hf = CONFIG_MAPPING[\"paligemma\"]()\n        vlm_config_hf._vocab_size = 257152  # noqa: SLF001\n        vlm_config_hf.image_token_index = 257152\n        vlm_config_hf.text_config.hidden_size = vlm_config.width\n        vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim\n        vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads\n        vlm_config_hf.text_config.head_dim = vlm_config.head_dim\n        vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth\n        vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads\n        vlm_config_hf.text_config.hidden_activation = \"gelu_pytorch_tanh\"\n        vlm_config_hf.text_config.torch_dtype = \"float32\"\n        vlm_config_hf.text_config.vocab_size = 257152\n        vlm_config_hf.text_config.use_adarms = use_adarms[0]\n        vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None\n        vlm_config_hf.vision_config.intermediate_size = 4304\n        vlm_config_hf.vision_config.projection_dim = 2048\n        vlm_config_hf.vision_config.projector_hidden_act = \"gelu_fast\"\n        vlm_config_hf.vision_config.torch_dtype = \"float32\"\n\n        action_expert_config_hf = CONFIG_MAPPING[\"gemma\"](\n            head_dim=action_expert_config.head_dim,\n            
hidden_size=action_expert_config.width,\n            intermediate_size=action_expert_config.mlp_dim,\n            num_attention_heads=action_expert_config.num_heads,\n            num_hidden_layers=action_expert_config.depth,\n            num_key_value_heads=action_expert_config.num_kv_heads,\n            vocab_size=257152,\n            hidden_activation=\"gelu_pytorch_tanh\",\n            torch_dtype=\"float32\",\n            use_adarms=use_adarms[1],\n            adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,\n        )\n\n        self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)\n        self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)\n        self.gemma_expert.model.embed_tokens = None\n\n        self.to_bfloat16_for_selected_params(precision)\n\n    def to_bfloat16_for_selected_params(self, precision: Literal[\"bfloat16\", \"float32\"] = \"bfloat16\"):\n        if precision == \"bfloat16\":\n            self.to(dtype=torch.bfloat16)\n        elif precision == \"float32\":\n            self.to(dtype=torch.float32)\n            return\n        else:\n            raise ValueError(f\"Invalid precision: {precision}\")\n\n        params_to_keep_float32 = [\n            \"vision_tower.vision_model.embeddings.patch_embedding.weight\",\n            \"vision_tower.vision_model.embeddings.patch_embedding.bias\",\n            \"vision_tower.vision_model.embeddings.position_embedding.weight\",\n            \"input_layernorm\",\n            \"post_attention_layernorm\",\n            \"model.norm\",\n        ]\n\n        for name, param in self.named_parameters():\n            if any(selector in name for selector in params_to_keep_float32):\n                param.data = param.data.to(dtype=torch.float32)\n\n    def embed_image(self, image: torch.Tensor):\n        return self.paligemma.model.get_image_features(image)\n\n    def embed_language_tokens(self, tokens: torch.Tensor):\n        return 
self.paligemma.language_model.embed_tokens(tokens)\n\n    def forward(\n        self,\n        attention_mask: torch.Tensor | None = None,\n        position_ids: torch.LongTensor | None = None,\n        past_key_values: list[torch.FloatTensor] | pytest.Cache | None = None,\n        inputs_embeds: list[torch.FloatTensor] | None = None,\n        use_cache: bool | None = None,\n        adarms_cond: list[torch.Tensor] | None = None,\n    ):\n        if adarms_cond is None:\n            adarms_cond = [None, None]\n        if inputs_embeds[1] is None:\n            prefix_output = self.paligemma.language_model.forward(\n                inputs_embeds=inputs_embeds[0],\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_values=past_key_values,\n                use_cache=use_cache,\n                adarms_cond=adarms_cond[0] if adarms_cond is not None else None,\n            )\n            prefix_past_key_values = prefix_output.past_key_values\n            prefix_output = prefix_output.last_hidden_state\n            suffix_output = None\n        elif inputs_embeds[0] is None:\n            suffix_output = self.gemma_expert.model.forward(\n                inputs_embeds=inputs_embeds[1],\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_values=past_key_values,\n                use_cache=use_cache,\n                adarms_cond=adarms_cond[1] if adarms_cond is not None else None,\n            )\n            suffix_output = suffix_output.last_hidden_state\n            prefix_output = None\n            prefix_past_key_values = None\n        else:\n            models = [self.paligemma.language_model, self.gemma_expert.model]\n            num_layers = self.paligemma.config.text_config.num_hidden_layers\n\n            # Check if gradient checkpointing is enabled for any of the models\n            use_gradient_checkpointing = (\n                
hasattr(self.gemma_expert.model, \"gradient_checkpointing\")\n                and self.gemma_expert.model.gradient_checkpointing\n                and self.training\n            ) or (hasattr(self, \"gradient_checkpointing\") and self.gradient_checkpointing and self.training)\n\n            # Force enable gradient checkpointing if we're in training mode and the model supports it\n            if self.training and hasattr(self.gemma_expert.model, \"gradient_checkpointing\"):\n                if not self.gemma_expert.model.gradient_checkpointing:\n                    print(\"Forcing gradient checkpointing to be enabled for Gemma expert model\")\n                    self.gemma_expert.model.gradient_checkpointing = True\n                use_gradient_checkpointing = True\n\n            # Debug gradient checkpointing status\n            if hasattr(self, \"_debug_gc_printed\") and not self._debug_gc_printed:\n                print(f\"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}\")\n                print(f\"Model training mode: {self.training}\")\n                print(\n                    f\"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}\"\n                )\n                if hasattr(self.gemma_expert.model, \"gradient_checkpointing\"):\n                    print(\n                        f\"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}\"\n                    )\n                self._debug_gc_printed = True\n\n            # Define the complete layer computation function for gradient checkpointing\n            def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):\n                models = [self.paligemma.language_model, self.gemma_expert.model]\n\n                query_states = []\n                key_states = []\n                value_states = []\n                gates = []\n            
    for i, hidden_states in enumerate(inputs_embeds):\n                    layer = models[i].layers[layer_idx]\n                    hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])  # noqa: PLW2901\n                    gates.append(gate)\n\n                    input_shape = hidden_states.shape[:-1]\n                    hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)\n                    query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n                    key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n                    value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n                    query_states.append(query_state)\n                    key_states.append(key_state)\n                    value_states.append(value_state)\n\n                # Concatenate and process attention\n                query_states = torch.cat(query_states, dim=2)\n                key_states = torch.cat(key_states, dim=2)\n                value_states = torch.cat(value_states, dim=2)\n\n                dummy_tensor = torch.zeros(\n                    query_states.shape[0],\n                    query_states.shape[2],\n                    query_states.shape[-1],\n                    device=query_states.device,\n                    dtype=query_states.dtype,\n                )\n                cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)\n                query_states, key_states = modeling_gemma.apply_rotary_pos_emb(\n                    query_states, key_states, cos, sin, unsqueeze_dim=1\n                )\n\n                batch_size = query_states.shape[0]\n                scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling\n\n                # Attention computation\n                att_output, _ = modeling_gemma.eager_attention_forward(\n                    
self.paligemma.language_model.layers[layer_idx].self_attn,\n                    query_states,\n                    key_states,\n                    value_states,\n                    attention_mask,\n                    scaling,\n                )\n                # Get head_dim from the current layer, not from the model\n                head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim\n                att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)\n\n                # Process layer outputs\n                outputs_embeds = []\n                start_pos = 0\n                for i, hidden_states in enumerate(inputs_embeds):\n                    layer = models[i].layers[layer_idx]\n                    end_pos = start_pos + hidden_states.shape[1]\n\n                    if att_output.dtype != layer.self_attn.o_proj.weight.dtype:\n                        att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)\n                    out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])\n\n                    # first residual\n                    out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001\n                    after_first_residual = out_emb.clone()\n                    out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])\n                    # Convert to bfloat16 if the next layer (mlp) uses bfloat16\n                    if layer.mlp.up_proj.weight.dtype == torch.bfloat16:\n                        out_emb = out_emb.to(dtype=torch.bfloat16)\n\n                    out_emb = layer.mlp(out_emb)\n                    # second residual\n                    out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)  # noqa: SLF001\n                    outputs_embeds.append(out_emb)\n                    start_pos = end_pos\n\n                return outputs_embeds\n\n            # Process all layers with gradient checkpointing 
if enabled\n            for layer_idx in range(num_layers):\n                if use_gradient_checkpointing:\n                    inputs_embeds = torch.utils.checkpoint.checkpoint(\n                        compute_layer_complete,\n                        layer_idx,\n                        inputs_embeds,\n                        attention_mask,\n                        position_ids,\n                        adarms_cond,\n                        use_reentrant=False,\n                        preserve_rng_state=False,\n                    )\n                else:\n                    inputs_embeds = compute_layer_complete(\n                        layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond\n                    )\n\n                # Old code removed - now using compute_layer_complete function above\n\n            # final norm\n            # Define final norm computation function for gradient checkpointing\n            def compute_final_norms(inputs_embeds, adarms_cond):\n                outputs_embeds = []\n                for i, hidden_states in enumerate(inputs_embeds):\n                    out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])\n                    outputs_embeds.append(out_emb)\n                return outputs_embeds\n\n            # Apply gradient checkpointing to final norm if enabled\n            if use_gradient_checkpointing:\n                outputs_embeds = torch.utils.checkpoint.checkpoint(\n                    compute_final_norms, inputs_embeds, adarms_cond, use_reentrant=False, preserve_rng_state=False\n                )\n            else:\n                outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)\n\n            prefix_output = outputs_embeds[0]\n            suffix_output = outputs_embeds[1]\n            prefix_past_key_values = None\n\n        return [prefix_output, suffix_output], prefix_past_key_values\n"
  },
  {
    "path": "src/openpi/models_pytorch/pi0_pytorch.py",
    "content": "import logging\nimport math\n\nimport torch\nfrom torch import Tensor\nfrom torch import nn\nimport torch.nn.functional as F  # noqa: N812\n\nimport openpi.models.gemma as _gemma\nfrom openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel\nimport openpi.models_pytorch.preprocessing_pytorch as _preprocessing\n\n\ndef get_safe_dtype(target_dtype, device_type):\n    \"\"\"Get a safe dtype for the given device type.\"\"\"\n    if device_type == \"cpu\":\n        # CPU doesn't support bfloat16, use float32 instead\n        if target_dtype == torch.bfloat16:\n            return torch.float32\n        if target_dtype == torch.float64:\n            return torch.float64\n    return target_dtype\n\n\ndef create_sinusoidal_pos_embedding(\n    time: torch.tensor, dimension: int, min_period: float, max_period: float, device=\"cpu\"\n) -> Tensor:\n    \"\"\"Computes sine-cosine positional embedding vectors for scalar positions.\"\"\"\n    if dimension % 2 != 0:\n        raise ValueError(f\"dimension ({dimension}) must be divisible by 2\")\n\n    if time.ndim != 1:\n        raise ValueError(\"The time tensor is expected to be of shape `(batch_size, )`.\")\n\n    dtype = get_safe_dtype(torch.float64, device.type)\n    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)\n    period = min_period * (max_period / min_period) ** fraction\n\n    # Compute the outer product\n    scaling_factor = 1.0 / period * 2 * math.pi\n    sin_input = scaling_factor[None, :] * time[:, None]\n    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)\n\n\ndef sample_beta(alpha, beta, bsize, device):\n    alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)\n    beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)\n    dist = torch.distributions.Beta(alpha_t, beta_t)\n    return dist.sample((bsize,))\n\n\ndef make_att_2d_masks(pad_masks, att_masks):\n    \"\"\"Copied from big_vision.\n\n    Tokens can 
attend to valid inputs tokens which have a cumulative mask_ar\n    smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to\n    setup several types of attention, for example:\n\n      [[1 1 1 1 1 1]]: pure causal attention.\n\n      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between\n          themselves and the last 3 tokens have a causal attention. The first\n          entry could also be a 1 without changing behaviour.\n\n      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a\n          block can attend all previous blocks and all tokens on the same block.\n\n    Args:\n      input_mask: bool[B, N] true if its part of the input, false if padding.\n      mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on\n        it and 0 where it shares the same attention mask as the previous token.\n    \"\"\"\n    if att_masks.ndim != 2:\n        raise ValueError(att_masks.ndim)\n    if pad_masks.ndim != 2:\n        raise ValueError(pad_masks.ndim)\n\n    cumsum = torch.cumsum(att_masks, dim=1)\n    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]\n    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]\n    return att_2d_masks & pad_2d_masks\n\n\nclass PI0Pytorch(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.pi05 = config.pi05\n\n        paligemma_config = _gemma.get_config(config.paligemma_variant)\n        action_expert_config = _gemma.get_config(config.action_expert_variant)\n\n        self.paligemma_with_expert = PaliGemmaWithExpertModel(\n            paligemma_config,\n            action_expert_config,\n            use_adarms=[False, True] if self.pi05 else [False, False],\n            precision=config.dtype,\n        )\n\n        self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width)\n        self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim)\n\n  
      if self.pi05:\n            self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)\n            self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)\n        else:\n            self.state_proj = nn.Linear(config.action_dim, action_expert_config.width)\n            self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)\n            self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)\n\n        torch.set_float32_matmul_precision(\"high\")\n        if config.pytorch_compile_mode is not None:\n            self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode)\n\n        # Initialize gradient checkpointing flag\n        self.gradient_checkpointing_enabled = False\n\n        msg = \"transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`.\"\n        try:\n            from transformers.models.siglip import check\n\n            if not check.check_whether_transformers_replace_is_installed_correctly():\n                raise ValueError(msg)\n        except ImportError:\n            raise ValueError(msg) from None\n\n    def gradient_checkpointing_enable(self):\n        \"\"\"Enable gradient checkpointing for memory optimization.\"\"\"\n        self.gradient_checkpointing_enabled = True\n        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True\n        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True\n        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True\n\n        logging.info(\"Enabled gradient checkpointing for PI0Pytorch model\")\n\n    def gradient_checkpointing_disable(self):\n        \"\"\"Disable gradient checkpointing.\"\"\"\n    
    self.gradient_checkpointing_enabled = False\n        self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False\n        self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False\n        self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False\n\n        logging.info(\"Disabled gradient checkpointing for PI0Pytorch model\")\n\n    def is_gradient_checkpointing_enabled(self):\n        \"\"\"Check if gradient checkpointing is enabled.\"\"\"\n        return self.gradient_checkpointing_enabled\n\n    def _apply_checkpoint(self, func, *args, **kwargs):\n        \"\"\"Helper method to apply gradient checkpointing if enabled.\"\"\"\n        if self.gradient_checkpointing_enabled and self.training:\n            return torch.utils.checkpoint.checkpoint(\n                func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs\n            )\n        return func(*args, **kwargs)\n\n    def _prepare_attention_masks_4d(self, att_2d_masks):\n        \"\"\"Helper method to prepare 4D attention masks for transformer.\"\"\"\n        att_2d_masks_4d = att_2d_masks[:, None, :, :]\n        return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)\n\n    def _preprocess_observation(self, observation, *, train=True):\n        \"\"\"Helper method to preprocess observation.\"\"\"\n        observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)\n        return (\n            list(observation.images.values()),\n            list(observation.image_masks.values()),\n            observation.tokenized_prompt,\n            observation.tokenized_prompt_mask,\n            observation.state,\n        )\n\n    def sample_noise(self, shape, device):\n        return torch.normal(\n            mean=0.0,\n            std=1.0,\n            size=shape,\n            dtype=torch.float32,\n            device=device,\n        )\n\n    def sample_time(self, bsize, device):\n        time_beta = 
sample_beta(1.5, 1.0, bsize, device)\n        time = time_beta * 0.999 + 0.001\n        return time.to(dtype=torch.float32, device=device)\n\n    def embed_prefix(\n        self, images, img_masks, lang_tokens, lang_masks\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"Embed images with SigLIP and language tokens with embedding layer to prepare\n        for PaliGemma transformer processing.\n        \"\"\"\n        embs = []\n        pad_masks = []\n        att_masks = []\n\n        # Process images\n        for img, img_mask in zip(images, img_masks, strict=True):\n\n            def image_embed_func(img):\n                return self.paligemma_with_expert.embed_image(img)\n\n            img_emb = self._apply_checkpoint(image_embed_func, img)\n\n            bsize, num_img_embs = img_emb.shape[:2]\n\n            embs.append(img_emb)\n            pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))\n\n            # Create attention masks so that image tokens attend to each other\n            att_masks += [0] * num_img_embs\n\n        # Process language tokens\n        def lang_embed_func(lang_tokens):\n            lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)\n            lang_emb_dim = lang_emb.shape[-1]\n            return lang_emb * math.sqrt(lang_emb_dim)\n\n        lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)\n\n        embs.append(lang_emb)\n        pad_masks.append(lang_masks)\n\n        # full attention between image and language inputs\n        num_lang_embs = lang_emb.shape[1]\n        att_masks += [0] * num_lang_embs\n\n        embs = torch.cat(embs, dim=1)\n        pad_masks = torch.cat(pad_masks, dim=1)\n        att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)\n\n        # Get batch size from the first dimension of the concatenated tensors\n        bsize = pad_masks.shape[0]\n        att_masks = att_masks[None, :].expand(bsize, 
len(att_masks))\n\n        return embs, pad_masks, att_masks\n\n    def embed_suffix(self, state, noisy_actions, timestep):\n        \"\"\"Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.\"\"\"\n        embs = []\n        pad_masks = []\n        att_masks = []\n\n        if not self.pi05:\n            if self.state_proj.weight.dtype == torch.float32:\n                state = state.to(torch.float32)\n\n            # Embed state\n            def state_proj_func(state):\n                return self.state_proj(state)\n\n            state_emb = self._apply_checkpoint(state_proj_func, state)\n\n            embs.append(state_emb[:, None, :])\n            bsize = state_emb.shape[0]\n            device = state_emb.device\n\n            state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)\n            pad_masks.append(state_mask)\n\n            # Set attention masks so that image and language inputs do not attend to state or actions\n            att_masks += [1]\n\n        # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]\n        time_emb = create_sinusoidal_pos_embedding(\n            timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0, device=timestep.device\n        )\n        time_emb = time_emb.type(dtype=timestep.dtype)\n\n        # Fuse timestep + action information using an MLP\n        def action_proj_func(noisy_actions):\n            return self.action_in_proj(noisy_actions)\n\n        action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)\n\n        if not self.pi05:\n            time_emb = time_emb[:, None, :].expand_as(action_emb)\n            action_time_emb = torch.cat([action_emb, time_emb], dim=2)\n\n            # Apply MLP layers\n            def mlp_func(action_time_emb):\n                x = self.action_time_mlp_in(action_time_emb)\n                x = F.silu(x)  # swish == silu\n                return self.action_time_mlp_out(x)\n\n  
          action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)\n            adarms_cond = None\n        else:\n            # time MLP (for adaRMS)\n            def time_mlp_func(time_emb):\n                x = self.time_mlp_in(time_emb)\n                x = F.silu(x)  # swish == silu\n                x = self.time_mlp_out(x)\n                return F.silu(x)\n\n            time_emb = self._apply_checkpoint(time_mlp_func, time_emb)\n            action_time_emb = action_emb\n            adarms_cond = time_emb\n\n        # Add to input tokens\n        embs.append(action_time_emb)\n\n        bsize, action_time_dim = action_time_emb.shape[:2]\n        action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)\n        pad_masks.append(action_time_mask)\n\n        # Set attention masks so that image, language and state inputs do not attend to action tokens\n        att_masks += [1] + ([0] * (self.config.action_horizon - 1))\n\n        embs = torch.cat(embs, dim=1)\n        pad_masks = torch.cat(pad_masks, dim=1)\n        att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)\n        att_masks = att_masks[None, :].expand(bsize, len(att_masks))\n\n        return embs, pad_masks, att_masks, adarms_cond\n\n    def forward(self, observation, actions, noise=None, time=None) -> Tensor:\n        \"\"\"Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)\"\"\"\n        images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=True)\n\n        if noise is None:\n            noise = self.sample_noise(actions.shape, actions.device)\n\n        if time is None:\n            time = self.sample_time(actions.shape[0], actions.device)\n\n        time_expanded = time[:, None, None]\n        x_t = time_expanded * noise + (1 - time_expanded) * actions\n        u_t = noise - actions\n\n        prefix_embs, prefix_pad_masks, prefix_att_masks = 
self.embed_prefix(images, img_masks, lang_tokens, lang_masks)\n        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)\n        if (\n            self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype\n            == torch.bfloat16\n        ):\n            suffix_embs = suffix_embs.to(dtype=torch.bfloat16)\n            prefix_embs = prefix_embs.to(dtype=torch.bfloat16)\n\n        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)\n        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)\n\n        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)\n        position_ids = torch.cumsum(pad_masks, dim=1) - 1\n\n        # Prepare attention masks\n        att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)\n\n        # Apply gradient checkpointing if enabled\n        def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):\n            (_, suffix_out), _ = self.paligemma_with_expert.forward(\n                attention_mask=att_2d_masks_4d,\n                position_ids=position_ids,\n                past_key_values=None,\n                inputs_embeds=[prefix_embs, suffix_embs],\n                use_cache=False,\n                adarms_cond=[None, adarms_cond],\n            )\n            return suffix_out\n\n        suffix_out = self._apply_checkpoint(\n            forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond\n        )\n\n        suffix_out = suffix_out[:, -self.config.action_horizon :]\n        suffix_out = suffix_out.to(dtype=torch.float32)\n\n        # Apply gradient checkpointing to final action projection if enabled\n        def action_out_proj_func(suffix_out):\n            return self.action_out_proj(suffix_out)\n\n        v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)\n\n        return F.mse_loss(u_t, v_t, reduction=\"none\")\n\n    
@torch.no_grad()\n    def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:\n        \"\"\"Do a full inference forward and compute the action (batch_size x num_steps x num_motors)\"\"\"\n        bsize = observation.state.shape[0]\n        if noise is None:\n            actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)\n            noise = self.sample_noise(actions_shape, device)\n\n        images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False)\n\n        prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)\n        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)\n        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1\n\n        # Compute image and language key value cache\n        prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)\n        self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = \"eager\"  # noqa: SLF001\n\n        _, past_key_values = self.paligemma_with_expert.forward(\n            attention_mask=prefix_att_2d_masks_4d,\n            position_ids=prefix_position_ids,\n            past_key_values=None,\n            inputs_embeds=[prefix_embs, None],\n            use_cache=True,\n        )\n\n        dt = -1.0 / num_steps\n        dt = torch.tensor(dt, dtype=torch.float32, device=device)\n\n        x_t = noise\n        time = torch.tensor(1.0, dtype=torch.float32, device=device)\n        while time >= -dt / 2:\n            expanded_time = time.expand(bsize)\n            v_t = self.denoise_step(\n                state,\n                prefix_pad_masks,\n                past_key_values,\n                x_t,\n                expanded_time,\n            )\n\n            # Euler step - use new tensor assignment instead of in-place operation\n            x_t = x_t + dt * v_t\n            time 
+= dt\n        return x_t\n\n    def denoise_step(\n        self,\n        state,\n        prefix_pad_masks,\n        past_key_values,\n        x_t,\n        timestep,\n    ):\n        \"\"\"Apply one denoising step of the noise `x_t` at a given timestep.\"\"\"\n        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)\n\n        suffix_len = suffix_pad_masks.shape[1]\n        batch_size = prefix_pad_masks.shape[0]\n        prefix_len = prefix_pad_masks.shape[1]\n\n        prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)\n\n        suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)\n\n        full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)\n\n        prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]\n        position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1\n\n        # Prepare attention masks\n        full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)\n        self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = \"eager\"  # noqa: SLF001\n\n        outputs_embeds, _ = self.paligemma_with_expert.forward(\n            attention_mask=full_att_2d_masks_4d,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=[None, suffix_embs],\n            use_cache=False,\n            adarms_cond=[None, adarms_cond],\n        )\n\n        suffix_out = outputs_embeds[1]\n        suffix_out = suffix_out[:, -self.config.action_horizon :]\n        suffix_out = suffix_out.to(dtype=torch.float32)\n        return self.action_out_proj(suffix_out)\n"
  },
  {
    "path": "src/openpi/models_pytorch/preprocessing_pytorch.py",
    "content": "from collections.abc import Sequence\nimport logging\n\nimport torch\n\nfrom openpi.shared import image_tools\n\nlogger = logging.getLogger(\"openpi\")\n\n# Constants moved from model.py\nIMAGE_KEYS = (\n    \"base_0_rgb\",\n    \"left_wrist_0_rgb\",\n    \"right_wrist_0_rgb\",\n)\n\nIMAGE_RESOLUTION = (224, 224)\n\n\ndef preprocess_observation_pytorch(\n    observation,\n    *,\n    train: bool = False,\n    image_keys: Sequence[str] = IMAGE_KEYS,\n    image_resolution: tuple[int, int] = IMAGE_RESOLUTION,\n):\n    \"\"\"Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.\n\n    This function avoids complex type annotations that can cause torch.compile issues.\n    \"\"\"\n    if not set(image_keys).issubset(observation.images):\n        raise ValueError(f\"images dict missing keys: expected {image_keys}, got {list(observation.images)}\")\n\n    batch_shape = observation.state.shape[:-1]\n\n    out_images = {}\n    for key in image_keys:\n        image = observation.images[key]\n\n        # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats\n        # Handle both [B, C, H, W] and [B, H, W, C] formats\n        is_channels_first = image.shape[1] == 3  # Check if channels are in dimension 1\n\n        if is_channels_first:\n            # Convert [B, C, H, W] to [B, H, W, C] for processing\n            image = image.permute(0, 2, 3, 1)\n\n        if image.shape[1:3] != image_resolution:\n            logger.info(f\"Resizing image {key} from {image.shape[1:3]} to {image_resolution}\")\n            image = image_tools.resize_with_pad_torch(image, *image_resolution)\n\n        if train:\n            # Convert from [-1, 1] to [0, 1] for PyTorch augmentations\n            image = image / 2.0 + 0.5\n\n            # Apply PyTorch-based augmentations\n            if \"wrist\" not in key:\n                # Geometric augmentations for non-wrist cameras\n                height, width = 
image.shape[1:3]\n\n                # Random crop and resize\n                crop_height = int(height * 0.95)\n                crop_width = int(width * 0.95)\n\n                # Random crop\n                max_h = height - crop_height\n                max_w = width - crop_width\n                if max_h > 0 and max_w > 0:\n                    # Use tensor operations instead of .item() for torch.compile compatibility\n                    start_h = torch.randint(0, max_h + 1, (1,), device=image.device)\n                    start_w = torch.randint(0, max_w + 1, (1,), device=image.device)\n                    image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]\n\n                # Resize back to original size\n                image = torch.nn.functional.interpolate(\n                    image.permute(0, 3, 1, 2),  # [b, h, w, c] -> [b, c, h, w]\n                    size=(height, width),\n                    mode=\"bilinear\",\n                    align_corners=False,\n                ).permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]\n\n                # Random rotation (small angles)\n                # Use tensor operations instead of .item() for torch.compile compatibility\n                angle = torch.rand(1, device=image.device) * 10 - 5  # Random angle between -5 and 5 degrees\n                if torch.abs(angle) > 0.1:  # Only rotate if angle is significant\n                    # Convert to radians\n                    angle_rad = angle * torch.pi / 180.0\n\n                    # Create rotation matrix\n                    cos_a = torch.cos(angle_rad)\n                    sin_a = torch.sin(angle_rad)\n\n                    # Apply rotation using grid_sample\n                    grid_x = torch.linspace(-1, 1, width, device=image.device)\n                    grid_y = torch.linspace(-1, 1, height, device=image.device)\n\n                    # Create meshgrid\n                    grid_y, grid_x = torch.meshgrid(grid_y, grid_x, 
indexing=\"ij\")\n\n                    # Expand to batch dimension\n                    grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)\n                    grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)\n\n                    # Apply rotation transformation\n                    grid_x_rot = grid_x * cos_a - grid_y * sin_a\n                    grid_y_rot = grid_x * sin_a + grid_y * cos_a\n\n                    # Stack and reshape for grid_sample\n                    grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)\n\n                    image = torch.nn.functional.grid_sample(\n                        image.permute(0, 3, 1, 2),  # [b, h, w, c] -> [b, c, h, w]\n                        grid,\n                        mode=\"bilinear\",\n                        padding_mode=\"zeros\",\n                        align_corners=False,\n                    ).permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]\n\n            # Color augmentations for all cameras\n            # Random brightness\n            # Use tensor operations instead of .item() for torch.compile compatibility\n            brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6  # Random factor between 0.7 and 1.3\n            image = image * brightness_factor\n\n            # Random contrast\n            # Use tensor operations instead of .item() for torch.compile compatibility\n            contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8  # Random factor between 0.6 and 1.4\n            mean = image.mean(dim=[1, 2, 3], keepdim=True)\n            image = (image - mean) * contrast_factor + mean\n\n            # Random saturation (convert to HSV, modify S, convert back)\n            # For simplicity, we'll just apply a random scaling to the color channels\n            # Use tensor operations instead of .item() for torch.compile compatibility\n            saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0  # Random factor between 0.5 
and 1.5\n            gray = image.mean(dim=-1, keepdim=True)\n            image = gray + (image - gray) * saturation_factor\n\n            # Clamp values to [0, 1]\n            image = torch.clamp(image, 0, 1)\n\n            # Back to [-1, 1]\n            image = image * 2.0 - 1.0\n\n        # Convert back to [B, C, H, W] format if it was originally channels-first\n        if is_channels_first:\n            image = image.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]\n\n        out_images[key] = image\n\n    # obtain mask\n    out_masks = {}\n    for key in out_images:\n        if key not in observation.image_masks:\n            # do not mask by default\n            out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)\n        else:\n            out_masks[key] = observation.image_masks[key]\n\n    # Create a simple object with the required attributes instead of using the complex Observation class\n    class SimpleProcessedObservation:\n        def __init__(self, **kwargs):\n            for key, value in kwargs.items():\n                setattr(self, key, value)\n\n    return SimpleProcessedObservation(\n        images=out_images,\n        image_masks=out_masks,\n        state=observation.state,\n        tokenized_prompt=observation.tokenized_prompt,\n        tokenized_prompt_mask=observation.tokenized_prompt_mask,\n        token_ar_mask=observation.token_ar_mask,\n        token_loss_mask=observation.token_loss_mask,\n    )\n"
  },
  {
    "path": "src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py",
    "content": "#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n#           This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.\n#               Do NOT edit this file manually as any edits will be overwritten by the generation of\n#             the file from the modular. If any change should be done, please apply the change to the\n#                          modular_gemma.py file directly. One of our CI enforces this.\n#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n# coding=utf-8\n# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.\n#\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Optional\nfrom ...configuration_utils import PretrainedConfig\n\n\nclass GemmaConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma\n    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the\n    defaults will yield a similar configuration to that of the Gemma-7B.\n    e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the\n    documentation from [`PretrainedConfig`] for more information.\n    Args:\n        vocab_size (`int`, *optional*, defaults to 256000):\n            Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the\n            `inputs_ids` passed when calling [`GemmaModel`]\n        hidden_size (`int`, *optional*, defaults to 3072):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 24576):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 28):\n            Number of hidden layers in the Transformer decoder.\n        num_attention_heads (`int`, *optional*, defaults to 16):\n            Number of attention heads for each attention layer in the Transformer decoder.\n        num_key_value_heads (`int`, *optional*, defaults to 16):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details, check out [this\n            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to\n            `num_attention_heads`.\n        head_dim (`int`, *optional*, defaults to 256):\n            The attention head dimension.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"gelu_pytorch_tanh\"`):\n            The legacy activation function. 
It is the activation `GemmaMLP` actually uses (`ACT2FN[config.hidden_act]`); `hidden_activation` does not override it.\n        hidden_activation (`str` or `function`, *optional*):\n            Stored on the config only. The bundled `GemmaMLP` reads `hidden_act`, so setting this field alone does not\n            change the MLP activation. `\"gelu_pytorch_tanh\"` uses an approximation of the `\"gelu\"` activation function.\n        max_position_embeddings (`int`, *optional*, defaults to 8192):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        pad_token_id (`int`, *optional*, defaults to 0):\n            Padding token id.\n        eos_token_id (`int`, *optional*, defaults to 1):\n            End of stream token id.\n        bos_token_id (`int`, *optional*, defaults to 2):\n            Beginning of stream token id.\n        tie_word_embeddings (`bool`, *optional*, defaults to `True`):\n            Whether to tie the input embeddings and the output (LM head) weights.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        use_adarms (`bool`, *optional*, defaults to `False`):\n            Whether to use adaptive RMSNorm (adaRMS). When enabled, each `GemmaRMSNorm` replaces its learned scale\n            with a dense projection of the `adarms_cond` input into a per-feature scale, shift and residual gate.\n        adarms_cond_dim (`int`, *optional*, defaults to `None`):\n            The dimension of the adaRMS conditioning vector (`adarms_cond`). If `None` and `use_adarms=True`, it is\n            set to `hidden_size`.\n    ```python\n    >>> from transformers import GemmaModel, GemmaConfig\n    >>> # Initializing a Gemma gemma-7b style configuration\n    >>> configuration = GemmaConfig()\n    >>> # Initializing a model from the gemma-7b style configuration\n    >>> model = GemmaModel(configuration)\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"gemma\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n    base_model_tp_plan = {\n        \"layers.*.self_attn.q_proj\": \"colwise\",\n        \"layers.*.self_attn.k_proj\": \"colwise\",\n        \"layers.*.self_attn.v_proj\": \"colwise\",\n        \"layers.*.self_attn.o_proj\": \"rowwise\",\n        \"layers.*.mlp.gate_proj\": \"colwise\",\n        \"layers.*.mlp.up_proj\": \"colwise\",\n        \"layers.*.mlp.down_proj\": \"rowwise\",\n    }\n    base_model_pp_plan = {\n        \"embed_tokens\": ([\"input_ids\"], [\"inputs_embeds\"]),\n        \"layers\": ([\"hidden_states\", \"attention_mask\"], [\"hidden_states\"]),\n        \"norm\": ([\"hidden_states\"], [\"hidden_states\"]),\n    }\n\n    def __init__(\n        self,\n        vocab_size=256000,\n        hidden_size=3072,\n        intermediate_size=24576,\n        num_hidden_layers=28,\n        num_attention_heads=16,\n        num_key_value_heads=16,\n        head_dim=256,\n        hidden_act=\"gelu_pytorch_tanh\",\n        hidden_activation=None,\n        max_position_embeddings=8192,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=0,\n        eos_token_id=1,\n        bos_token_id=2,\n        tie_word_embeddings=True,\n        rope_theta=10000.0,\n        attention_bias=False,\n        attention_dropout=0.0,\n        use_adarms: bool = False,\n        adarms_cond_dim: Optional[int] = None,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.head_dim = head_dim\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.hidden_activation = hidden_activation\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        self.use_adarms = use_adarms\n        self.adarms_cond_dim = adarms_cond_dim\n\n        # Set default for adarms_cond_dim if use_adarms is True\n        if self.use_adarms and self.adarms_cond_dim is None:\n            self.adarms_cond_dim = self.hidden_size\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\n__all__ = [\"GemmaConfig\"]"
  },
  {
    "path": "src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py",
    "content": "#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n#           This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.\n#               Do NOT edit this file manually as any edits will be overwritten by the generation of\n#             the file from the modular. If any change should be done, please apply the change to the\n#                          modular_gemma.py file directly. One of our CI enforces this.\n#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n# coding=utf-8\n# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.\n#\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Callable, Optional, Union\n\nimport torch\nfrom torch import nn\n\nfrom ...activations import ACT2FN\nfrom ...cache_utils import Cache, DynamicCache\nfrom ...generation import GenerationMixin\nfrom ...masking_utils import create_causal_mask\nfrom ...modeling_flash_attention_utils import FlashAttentionKwargs\nfrom ...modeling_layers import GradientCheckpointingLayer\nfrom ...modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    SequenceClassifierOutputWithPast,\n    TokenClassifierOutput,\n)\nfrom ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update\nfrom ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel\nfrom ...processing_utils import Unpack\nfrom ...utils import LossKwargs, auto_docstring, 
can_return_tuple, logging\nfrom .configuration_gemma import GemmaConfig\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass GemmaRMSNorm(nn.Module):\n    def __init__(self, dim: int, eps: float = 1e-6, cond_dim: Optional[int] = None):\n        super().__init__()\n        self.eps = eps\n        self.dim = dim\n        self.cond_dim = cond_dim\n        \n        # Dense layer for adaptive normalization (if cond_dim is provided)\n        if cond_dim is not None:\n            #self.dense = nn.Linear(cond_dim, dim * 3, bias=True, dtype=torch.bfloat16)\n            self.dense = nn.Linear(cond_dim, dim * 3, bias=True)\n            # Initialize with zeros (matches source implementation)\n            nn.init.zeros_(self.dense.weight)\n        else:\n            self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16))\n            self.dense = None\n\n    def _norm(self, x):\n        # Compute variance in float32 (like the source implementation)\n        var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)\n        # Compute normalization in float32\n        normed_inputs = x * torch.rsqrt(var + self.eps)\n        return normed_inputs\n\n    def forward(self, x, cond=None):\n        dtype = x.dtype  # original dtype, could be half-precision\n        normed_inputs = self._norm(x)\n        \n        if cond is None or self.dense is None:\n            # regular RMSNorm\n            # scale by learned parameter in float32 (matches source implementation)\n            normed_inputs = normed_inputs * (1.0 + self.weight.float())\n            return normed_inputs.to(dtype), None  # return in original dtype with None gate\n        \n        # adaptive RMSNorm (if cond is provided and dense layer exists)\n        if cond.shape[-1] != self.cond_dim:\n            raise ValueError(f\"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}\")\n        \n        #self.dense.to(dtype=torch.bfloat16).to(dtype=torch.float32)\n        modulation = 
self.dense(cond)\n        # Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features]\n        if len(x.shape) == 3:  # [batch, seq, features]\n            modulation = modulation.unsqueeze(1)\n        \n        scale, shift, gate = torch.chunk(modulation, 3, dim=-1)\n        \n        # Apply adaptive normalization: use model weight dtype to ensure compatibility\n        # model_dtype = self.dense.weight.dtype  # Use the model's dtype (bfloat16)\n        # scale = scale.to(model_dtype)\n        # shift = shift.to(model_dtype)\n        # gate = gate.to(model_dtype)\n        # normed_inputs = normed_inputs.to(model_dtype)  # Convert normed_inputs to model dtype\n        \n        normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)\n\n        return normed_inputs.to(dtype), gate.to(dtype)\n\n    def extra_repr(self):\n        # `self.weight` only exists when the norm is not adaptive (cond_dim is None), so it must not be read here:\n        # printing a model with adaRMS enabled would raise AttributeError. `(self.dim,)` renders identically.\n        repr_str = f\"{(self.dim,)}, eps={self.eps}\"\n        if self.dense is not None:\n            repr_str += f\", adaptive=True, cond_dim={self.cond_dim}\"\n        return repr_str\n\n\nclass GemmaMLP(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\nclass GemmaRotaryEmbedding(nn.Module):\n    def __init__(self, config: GemmaConfig, device=None):\n        super().__init__()\n        # BC: \"rope_type\" was originally \"type\"\n        if hasattr(config, 
\"rope_scaling\") and config.rope_scaling is not None:\n            self.rope_type = config.rope_scaling.get(\"rope_type\", config.rope_scaling.get(\"type\"))\n        else:\n            self.rope_type = \"default\"\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    @torch.no_grad()\n    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)\n    def forward(self, x, position_ids):\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)\n        position_ids_expanded = position_ids[:, None, :].float()\n\n        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):  # Force float32\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos() * self.attention_scaling\n            sin = emb.sin() * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos 
(`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`, *optional*):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef _gated_residual(x, y, gate):\n    \"\"\"\n    Applies gated residual connection with optional gate parameter.\n    \n    Args:\n        x: Input tensor (residual)\n        y: Output tensor to be added\n        gate: Optional gate tensor to modulate the addition\n        \n    Returns:\n        x + y if gate is None, otherwise x + y * gate\n    \"\"\"\n    if x is None and y is None:\n        return None\n    if x is None or y is None:\n        return x if x is not None else y\n    if gate is None:\n        return x + y\n    return x + y * gate\n\n\ndef eager_attention_forward(\n    module: nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    scaling: float,\n    dropout: float = 0.0,\n    **kwargs,\n):\n    key_states = repeat_kv(key, module.num_key_value_groups)\n    value_states = repeat_kv(value, module.num_key_value_groups)\n\n    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling\n    if attention_mask is not None:\n        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n        attn_weights = attn_weights + causal_mask\n\n    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)\n    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)\n    attn_output = torch.matmul(attn_weights, value_states)\n    attn_output = attn_output.transpose(1, 2).contiguous()\n\n    return attn_output, attn_weights\n\n\nclass 
GemmaAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: GemmaConfig, layer_idx: int):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        self.head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads\n        self.scaling = self.head_dim**-0.5\n        self.attention_dropout = config.attention_dropout\n        self.is_causal = True\n\n        self.q_proj = nn.Linear(\n            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.k_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.v_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.o_proj = nn.Linear(\n            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor],\n        past_key_value: Optional[Cache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        use_cache: bool = False,\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:        \n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n        value_states = 
self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n        cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        # Use cache if provided\n        if past_key_value is not None:\n            if use_cache:\n                # sin and cos are specific to RoPE models; cache_position needed for the static cache\n                cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n            else:\n                key_states = torch.cat([past_key_value[self.layer_idx][0], key_states], dim=2)\n                value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2)\n\n        attention_interface: Callable = eager_attention_forward\n        if self.config._attn_implementation != \"eager\":\n            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n        attn_output, attn_weights = attention_interface(\n            self,\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            dropout=0.0 if not self.training else self.attention_dropout,\n            scaling=self.scaling,\n            **kwargs,\n        )\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output, attn_weights\n\n\nclass GemmaDecoderLayer(GradientCheckpointingLayer):\n    def __init__(self, config: GemmaConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)\n\n        self.mlp = GemmaMLP(config)\n        cond_dim = getattr(config, 'adarms_cond_dim', None) if getattr(config, 'use_adarms', False) else None\n        
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)\n        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC\n        adarms_cond: Optional[torch.Tensor] = None,\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        residual = hidden_states\n        hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond)\n\n        # Self Attention\n        hidden_states, self_attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            position_embeddings=position_embeddings,\n            **kwargs,\n        )\n        hidden_states = _gated_residual(residual, hidden_states, gate)\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = _gated_residual(residual, hidden_states, gate)\n\n        outputs = (hidden_states,)\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        return 
outputs\n\n\n@auto_docstring\nclass GemmaPreTrainedModel(PreTrainedModel):\n    config_class = GemmaConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"GemmaDecoderLayer\"]\n    _skip_keys_device_placement = [\"past_key_values\"]\n    _supports_flash_attn_3 = True\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_flex_attn = True\n    _supports_cache_class = True\n    _supports_quantized_cache = True\n    _supports_static_cache = True\n    _supports_attention_backend = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, GemmaRMSNorm):\n            if hasattr(module, 'weight'):\n                module.weight.data.fill_(1.0)\n\n\n@auto_docstring\nclass GemmaModel(GemmaPreTrainedModel):\n    def __init__(self, config: GemmaConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList(\n            [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n\n        cond_dim = getattr(config, 'adarms_cond_dim', None) if getattr(config, 'use_adarms', False) else None\n        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)\n        self.rotary_emb = GemmaRotaryEmbedding(config=config)\n        self.gradient_checkpointing = 
False\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        adarms_cond: Optional[torch.Tensor] = None,\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> BaseModelOutputWithPast:\n        \"\"\"\n        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):\n            Condition for ADARMS.\n        \"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\"You must specify exactly one of input_ids or inputs_embeds\")\n\n        if self.gradient_checkpointing and self.training and use_cache:\n            logger.warning_once(\n                \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`.\"\n            )\n            use_cache = False\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if use_cache and past_key_values is None:\n            past_key_values = DynamicCache()\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = create_causal_mask(\n            config=self.config,\n            input_embeds=inputs_embeds,\n            attention_mask=attention_mask,\n            cache_position=cache_position,\n            past_key_values=past_key_values,\n            position_ids=position_ids,\n        )\n\n        # embed positions\n        hidden_states = inputs_embeds\n        # Convert to bfloat16 if the first layer uses bfloat16\n        if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:\n            hidden_states = hidden_states.to(torch.bfloat16)\n\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        # normalized\n        # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5\n        # See https://github.com/huggingface/transformers/pull/29402\n        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)\n        #hidden_states = hidden_states * normalizer\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n\n        for decoder_layer in self.layers[: self.config.num_hidden_layers]:\n            if 
output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask=causal_mask,\n                position_ids=position_ids,\n                past_key_value=past_key_values,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                cache_position=cache_position,\n                position_embeddings=position_embeddings,\n                adarms_cond=adarms_cond,\n                **kwargs,\n            )\n\n            hidden_states = layer_outputs[0]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states, _ = self.norm(hidden_states, adarms_cond)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values if use_cache else None,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n\nclass KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...\n\n\n@auto_docstring\nclass GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):\n    _tied_weights_keys = [\"lm_head.weight\"]\n    _tp_plan = {\"lm_head\": \"colwise_rep\"}\n    _pp_plan = {\"lm_head\": ([\"hidden_states\"], [\"logits\"])}\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = GemmaModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def 
get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        adarms_cond: Optional[torch.Tensor] = None,\n        **kwargs: Unpack[KwargsForCausalLM],\n    ) -> CausalLMOutputWithPast:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n            config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored\n            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):\n            Condition for ADARMS.\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, GemmaForCausalLM\n\n        >>> model = GemmaForCausalLM.from_pretrained(\"google/gemma-7b\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-7b\")\n\n        >>> prompt = \"What is your favorite condiment?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"What is your favorite condiment?\"\n        ```\"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs: BaseModelOutputWithPast = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            cache_position=cache_position,\n            adarms_cond=adarms_cond,\n            **kwargs,\n        )\n\n        hidden_states = outputs.last_hidden_state\n        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n        
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep\n        logits = self.lm_head(hidden_states[:, slice_indices, :])\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@auto_docstring(\n    custom_intro=\"\"\"\n    The Gemma Model transformer with a sequence classification head on top (linear layer).\n\n    [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\"\n)\nclass GemmaForSequenceClassification(GemmaPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = GemmaModel(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        adarms_cond: Optional[torch.Tensor] = None,\n    ) -> SequenceClassifierOutputWithPast:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n\n        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):\n            Condition for ADARMS.\n        \"\"\"\n\n        transformer_outputs: BaseModelOutputWithPast = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            adarms_cond=adarms_cond,\n        )\n        hidden_states = transformer_outputs.last_hidden_state\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\"Cannot handle batch sizes > 1 if no padding token is defined.\")\n        if self.config.pad_token_id is None:\n            last_non_pad_token = -1\n        elif input_ids is not None:\n            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id\n            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)\n            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)\n            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)\n        else:\n            last_non_pad_token = -1\n            logger.warning_once(\n                f\"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be \"\n                \"unexpected if using padding tokens in conjunction with `inputs_embeds.`\"\n            )\n\n        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n\n@auto_docstring\nclass GemmaForTokenClassification(GemmaPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = GemmaModel(config)\n        if getattr(config, \"classifier_dropout\", None) is not None:\n            classifier_dropout = config.classifier_dropout\n        elif getattr(config, \"hidden_dropout\", None) is not None:\n            classifier_dropout = config.hidden_dropout\n        else:\n            classifier_dropout = 0.1\n        self.dropout = nn.Dropout(classifier_dropout)\n        self.score = nn.Linear(config.hidden_size, config.num_labels)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: 
Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        adarms_cond: Optional[torch.Tensor] = None,\n    ) -> TokenClassifierOutput:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n\n        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):\n            Condition for ADARMS.\n        \"\"\"\n\n        outputs: BaseModelOutputWithPast = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            adarms_cond=adarms_cond,\n        )\n        sequence_output = outputs.last_hidden_state\n        sequence_output = self.dropout(sequence_output)\n        logits = self.score(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits, labels, self.config)\n\n        return TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n__all__ = [\n    \"GemmaModel\",\n    \"GemmaForCausalLM\",\n    \"GemmaForSequenceClassification\",\n    \"GemmaForTokenClassification\",\n    \"GemmaPreTrainedModel\",\n]\n"
  },
  {
    "path": "src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py",
    "content": "# coding=utf-8\n# Copyright 2024 the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch PaliGemmamodel.\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Optional, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\n\nfrom ...cache_utils import Cache, HybridCache, StaticCache\nfrom ...generation import GenerationMixin\nfrom ...modeling_flash_attention_utils import FlashAttentionKwargs\nfrom ...modeling_outputs import BaseModelOutputWithPast\nfrom ...modeling_utils import PreTrainedModel\nfrom ...processing_utils import Unpack\nfrom ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging\nfrom ..auto import AutoModel\nfrom .configuration_paligemma import PaliGemmaConfig\n\n\nlogger = logging.get_logger(__name__)\n\n\n@dataclass\n@auto_docstring(\n    custom_intro=\"\"\"\n    Base class for Paligemma outputs, with hidden states and attentions.\n    \"\"\"\n)\nclass PaligemmaModelOutputWithPast(BaseModelOutputWithPast):\n    r\"\"\"\n    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape\n        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)\n\n        Contains pre-computed 
hidden-states (key and values in the self-attention blocks) that can be used (see\n        `past_key_values` input) to speed up sequential decoding.\n    image_hidden_states (`torch.FloatTensor`, *optional*):\n        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.\n        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.\n    \"\"\"\n\n    image_hidden_states: Optional[torch.FloatTensor] = None\n\n\n@dataclass\n@auto_docstring(\n    custom_intro=\"\"\"\n    Base class for PaliGemma causal language model (or autoregressive) outputs.\n    \"\"\"\n)\nclass PaliGemmaCausalLMOutputWithPast(ModelOutput):\n    r\"\"\"\n    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):\n        Language modeling loss (for next-token prediction).\n    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):\n        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).\n    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape\n        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)\n\n        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see\n        `past_key_values` input) to speed up sequential decoding.\n    image_hidden_states (`torch.FloatTensor`, *optional*):\n        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.\n        image_hidden_states of the model produced by the vision encoder after projecting last hidden state.\n    \"\"\"\n\n    loss: Optional[torch.FloatTensor] = None\n    logits: Optional[torch.FloatTensor] = None\n    
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None\n    hidden_states: Optional[tuple[torch.FloatTensor]] = None\n    attentions: Optional[tuple[torch.FloatTensor]] = None\n    image_hidden_states: Optional[torch.FloatTensor] = None\n\n\nclass PaliGemmaMultiModalProjector(nn.Module):\n    def __init__(self, config: PaliGemmaConfig):\n        super().__init__()\n        self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)\n\n    def forward(self, image_features):\n        hidden_states = self.linear(image_features)\n\n        return hidden_states\n\n\n@auto_docstring\nclass PaliGemmaPreTrainedModel(PreTrainedModel):\n    config_class = PaliGemmaConfig\n    base_model_prefix = \"\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"PaliGemmaMultiModalProjector\"]\n    _skip_keys_device_placement = \"past_key_values\"\n    _supports_cache_class = True\n    _supports_quantized_cache = True\n    _supports_static_cache = True\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_flex_attn = True\n    _supports_attention_backend = True\n\n    def _init_weights(self, module):\n        # important: this ported version of PaliGemma isn't meant for training from scratch - only\n        # inference and fine-tuning\n        std = getattr(self.config, \"initializer_range\", self.config.get_text_config().initializer_range)\n\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n\n\n@auto_docstring(\n    custom_intro=\"\"\"\n    The base PaliGemma model, which consists of a vision backbone and a language model without a language modeling head.\n    \"\"\"\n)\nclass PaliGemmaModel(PaliGemmaPreTrainedModel):\n    _checkpoint_conversion_mapping = {\"language_model.model\": \"language_model\"}\n    # we are filtering the logits/labels so we 
shouldn't divide the loss based on num_items_in_batch\n    accepts_loss_kwargs = False\n\n    def __init__(self, config: PaliGemmaConfig):\n        super().__init__(config)\n        self.vision_tower = AutoModel.from_config(config=config.vision_config)\n        self.multi_modal_projector = PaliGemmaMultiModalProjector(config)\n        self.vocab_size = config.text_config.vocab_size\n\n        language_model = AutoModel.from_config(config=config.text_config)\n        self.language_model = language_model\n\n        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1\n        self.post_init()\n\n    # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma\n    def get_input_embeddings(self):\n        return self.language_model.get_input_embeddings()\n\n    # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma\n    def set_input_embeddings(self, value):\n        self.language_model.set_input_embeddings(value)\n\n    def set_decoder(self, decoder):\n        self.language_model = decoder\n\n    def get_decoder(self):\n        return self.language_model\n\n    def _update_causal_mask(\n        self,\n        attention_mask,\n        token_type_ids=None,\n        past_key_values=None,\n        cache_position=None,\n        input_tensor=None,\n        is_training: Optional[bool] = None,\n    ):\n        if self.config.text_config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n        is_training = is_training if is_training is not None else self.training\n        using_static_cache = isinstance(past_key_values, StaticCache)\n        min_dtype = torch.finfo(self.dtype).min\n        if input_tensor is None:\n            input_tensor = attention_mask\n\n        inputs_lead_dim, sequence_length = 
input_tensor.shape[:2]\n        if using_static_cache:\n            target_length = past_key_values.get_max_cache_shape()\n        elif isinstance(past_key_values, HybridCache):\n            target_length = past_key_values.get_max_cache_shape()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else cache_position[0] + sequence_length + 1\n            )\n\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.\n            return attention_mask\n\n        causal_mask = torch.full(\n            (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device\n        )\n        # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below\n        if sequence_length != 1:\n            if is_training:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            else:\n                causal_mask[:, :sequence_length] = 0.0\n\n        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)\n        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)\n        if attention_mask is not None:\n            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n            mask_length = attention_mask.shape[-1]\n\n            # First unmask prefix tokens during training\n            if is_training:\n                if token_type_ids is None:\n                    raise ValueError(\"Token type ids must be provided during training\")\n                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    token_type_ids[:, None, None, :].to(causal_mask.device) 
== 0, 0\n                )\n\n            # Then apply padding mask (will mask pad tokens)\n            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)\n            padding_mask = padding_mask == 0\n            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                padding_mask, min_dtype\n            )\n\n        return causal_mask\n\n    def get_image_features(self, pixel_values: torch.FloatTensor):\n        \"\"\"\n        Obtains image last hidden states from the vision tower and apply multimodal projection.\n\n        Args:\n            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)\n               The tensors corresponding to the input images.\n        Returns:\n            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).\n        \"\"\"\n        image_outputs = self.vision_tower(pixel_values)\n        selected_image_feature = image_outputs.last_hidden_state\n        image_features = self.multi_modal_projector(selected_image_feature)\n        return image_features\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        pixel_values: torch.FloatTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,\n        token_type_ids: Optional[torch.LongTensor] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs: 
Unpack[FlashAttentionKwargs],\n    ) -> Union[tuple, PaligemmaModelOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.\n\n        Example:\n\n        ```python\n        >>> from PIL import Image\n        >>> import requests\n        >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration\n\n        >>> model = PaliGemmaForConditionalGeneration.from_pretrained(\"google/paligemma2-3b-mix-224\")\n        >>> processor = AutoProcessor.from_pretrained(\"google/paligemma2-3b-mix-224\")\n\n        >>> prompt = \"Where is the cat standing?\"\n        >>> url = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg\"\n        >>> image = Image.open(requests.get(url, stream=True).raw)\n\n        >>> inputs = processor(images=image, text=prompt,  return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(**inputs,)\n        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Where is the cat standing?\\nsnow\"\n        ```\"\"\"\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\"You must specify exactly one of input_ids or inputs_embeds\")\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if 
return_dict is not None else self.config.use_return_dict\n\n        is_training = token_type_ids is not None and labels is not None\n\n        # Replace image token ids with PAD (id 0) if the image token id is OOV, to avoid index errors\n        if input_ids is not None and self.config.image_token_id >= self.vocab_size:\n            special_image_mask = input_ids == self.config.image_token_id\n            llm_input_ids = input_ids.clone()\n            llm_input_ids[special_image_mask] = 0\n        else:\n            llm_input_ids = input_ids\n\n        if inputs_embeds is None:\n            inputs_embeds = self.get_input_embeddings()(llm_input_ids)\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0) + 1  # Paligemma positions are 1-indexed\n\n        # Merge text and images\n        if pixel_values is not None:\n            image_features = self.get_image_features(pixel_values)\n\n            if input_ids is None:\n                special_image_mask = inputs_embeds == self.get_input_embeddings()(\n                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)\n                )\n            else:\n                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)\n                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)\n\n            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():\n                image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]\n                raise ValueError(\n                    f\"Number of images does not match number of 
special image tokens in the input text. \"\n                    f\"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} \"\n                    \"tokens from image embeddings.\"\n                )\n            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)\n            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training\n        )\n        outputs = self.language_model(\n            attention_mask=causal_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=True,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        return PaligemmaModelOutputWithPast(\n            last_hidden_state=outputs.last_hidden_state,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            image_hidden_states=image_features if pixel_values is not None else None,\n        )\n\n\nclass KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...\n\n\n@auto_docstring(\n    custom_intro=\"\"\"\n    The PaliGemma model, which consists of a vision backbone and a language model, with a language modeling head on top.\n    \"\"\"\n)\nclass PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):\n    _checkpoint_conversion_mapping = {\n        \"^language_model.model\": \"model.language_model\",\n        \"^vision_tower\": \"model.vision_tower\",\n        \"^multi_modal_projector\": \"model.multi_modal_projector\",\n        
\"^language_model.lm_head\": \"lm_head\",\n    }\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config: PaliGemmaConfig):\n        super().__init__(config)\n        self.model = PaliGemmaModel(config)\n        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.get_input_embeddings()\n\n    def set_input_embeddings(self, value):\n        self.model.set_input_embeddings(value)\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model.set_decoder(decoder)\n\n    def get_decoder(self):\n        return self.model.get_decoder()\n\n    def get_image_features(self, pixel_values):\n        return self.model.get_image_features(pixel_values)\n\n    # Make submodules available through the conditional-generation class for backward compatibility (BC)\n    @property\n    def language_model(self):\n        return self.model.language_model\n\n    @property\n    def vision_tower(self):\n        return self.model.vision_tower\n\n    @property\n    def multi_modal_projector(self):\n        return self.model.multi_modal_projector\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        pixel_values: torch.FloatTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,\n        token_type_ids: Optional[torch.LongTensor] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: 
Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs: Unpack[KwargsForCausalLM],\n    ) -> Union[tuple, PaliGemmaCausalLMOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.\n\n        Example:\n\n        ```python\n        >>> from PIL import Image\n        >>> import requests\n        >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration\n\n        >>> model = PaliGemmaForConditionalGeneration.from_pretrained(\"google/paligemma2-3b-mix-224\")\n        >>> processor = AutoProcessor.from_pretrained(\"google/paligemma2-3b-mix-224\")\n\n        >>> prompt = \"Where is the cat standing?\"\n        >>> url = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg\"\n        >>> image = Image.open(requests.get(url, stream=True).raw)\n\n        >>> inputs = processor(images=image, text=prompt,  return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(**inputs,)\n        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Where is the cat standing?\\nsnow\"\n        ```\"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n    
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.model(\n            input_ids=input_ids,\n            pixel_values=pixel_values,\n            token_type_ids=token_type_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            labels=labels,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=True,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        hidden_states = outputs[0]\n        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep\n        logits = self.lm_head(hidden_states[:, slice_indices, :])\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(\n                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs\n            )\n\n        return PaliGemmaCausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            image_hidden_states=outputs.image_hidden_states,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        inputs_embeds=None,\n        cache_position=None,\n        position_ids=None,\n        pixel_values=None,\n        attention_mask=None,\n        token_type_ids=None,\n        use_cache=True,\n        logits_to_keep=None,\n        labels=None,\n        **kwargs,\n    ):\n        # Overwritten -- custom `position_ids` and 
`pixel_values` handling\n        model_inputs = super().prepare_inputs_for_generation(\n            input_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            cache_position=cache_position,\n            use_cache=use_cache,\n            logits_to_keep=logits_to_keep,\n            token_type_ids=token_type_ids,\n            **kwargs,\n        )\n\n        # position_ids in Paligemma are 1-indexed\n        if model_inputs.get(\"position_ids\") is not None:\n            model_inputs[\"position_ids\"] += 1\n        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore\n        # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always\n        if cache_position[0] == 0:\n            model_inputs[\"pixel_values\"] = pixel_values\n        is_training = token_type_ids is not None and labels is not None\n        if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):\n            input_tensor = inputs_embeds if inputs_embeds is not None else input_ids\n            causal_mask = self.model._update_causal_mask(\n                attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training\n            )\n            model_inputs[\"attention_mask\"] = causal_mask\n\n        return model_inputs\n\n    @staticmethod\n    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position\n    def _prepare_4d_causal_attention_mask_with_cache_position(\n        attention_mask: torch.Tensor,\n        sequence_length: int,\n        target_length: int,\n        dtype: torch.dtype,\n        cache_position: torch.Tensor,\n        batch_size: int,\n        **kwargs,\n    ):\n        \"\"\"\n        Creates a causal 4D mask of shape 
`(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape\n        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.\n\n        Args:\n            attention_mask (`torch.Tensor`):\n                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape\n                `(batch_size, 1, query_length, key_value_length)`.\n            sequence_length (`int`):\n                The sequence length being processed.\n            target_length (`int`):\n                The target length: when generating with static cache, the mask should be as long as the static cache,\n                to account for the 0 padding, the part of the cache that is not filled yet.\n            dtype (`torch.dtype`):\n                The dtype to use for the 4D attention mask.\n            cache_position (`torch.Tensor`):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            batch_size (`torch.Tensor`):\n                Batch size.\n        \"\"\"\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.\n            causal_mask = attention_mask\n        else:\n            min_dtype = torch.finfo(dtype).min\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device\n            )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place 
edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(\n                    causal_mask.device\n                )\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n\n        return causal_mask\n\n\n__all__ = [\"PaliGemmaForConditionalGeneration\", \"PaliGemmaPreTrainedModel\", \"PaliGemmaModel\"]\n"
  },
  {
    "path": "src/openpi/models_pytorch/transformers_replace/models/siglip/check.py",
    "content": "import transformers\n\ndef check_whether_transformers_replace_is_installed_correctly():\n    return transformers.__version__ == \"4.53.2\""
  },
  {
    "path": "src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py",
    "content": "# coding=utf-8\n# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Siglip model.\"\"\"\n\nimport math\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, Optional, Union\n\nimport numpy as np\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\nfrom torch.nn.init import _calculate_fan_in_and_fan_out\n\nfrom ...activations import ACT2FN\nfrom ...modeling_attn_mask_utils import _prepare_4d_attention_mask\nfrom ...modeling_layers import GradientCheckpointingLayer\nfrom ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput\nfrom ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel\nfrom ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int\nfrom .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig\n\n\nlogger = logging.get_logger(__name__)\n\n\ndef _trunc_normal_(tensor, mean, std, a, b):\n    # Cut & paste from PyTorch official master until it's in a few official releases - RW\n    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n    def norm_cdf(x):\n        # Computes standard normal cumulative distribution function\n        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0\n\n    if (mean < a - 2 * 
std) or (mean > b + 2 * std):\n        warnings.warn(\n            \"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. \"\n            \"The distribution of values may be incorrect.\",\n            stacklevel=2,\n        )\n\n    # Values are generated by using a truncated uniform distribution and\n    # then using the inverse CDF for the normal distribution.\n    # Get upper and lower cdf values\n    l = norm_cdf((a - mean) / std)\n    u = norm_cdf((b - mean) / std)\n\n    # Uniformly fill tensor with values from [l, u], then translate to\n    # [2l-1, 2u-1].\n    tensor.uniform_(2 * l - 1, 2 * u - 1)\n\n    # Use inverse cdf transform for normal distribution to get truncated\n    # standard normal\n    tensor.erfinv_()\n\n    # Transform to proper mean, std\n    tensor.mul_(std * math.sqrt(2.0))\n    tensor.add_(mean)\n\n    # Clamp to ensure it's in the proper range\n    tensor.clamp_(min=a, max=b)\n\n\ndef trunc_normal_tf_(\n    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0\n) -> torch.Tensor:\n    \"\"\"Fills the input Tensor with values drawn from a truncated\n    normal distribution. The values are effectively drawn from the\n    normal distribution :math:`\\\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n    with values outside :math:`[a, b]` redrawn until they are within\n    the bounds. 
The method used for generating the random values works\n    best when :math:`a \\\\leq \\text{mean} \\\\leq b`.\n\n    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the\n    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0\n    and the result is subsequently scaled and shifted by the mean and std args.\n\n    Args:\n        tensor: an n-dimensional `torch.Tensor`\n        mean: the mean of the normal distribution\n        std: the standard deviation of the normal distribution\n        a: the minimum cutoff value\n        b: the maximum cutoff value\n    \"\"\"\n    with torch.no_grad():\n        _trunc_normal_(tensor, 0, 1.0, a, b)\n        tensor.mul_(std).add_(mean)\n\n\ndef variance_scaling_(tensor, scale=1.0, mode=\"fan_in\", distribution=\"normal\"):\n    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)\n    if mode == \"fan_in\":\n        denom = fan_in\n    elif mode == \"fan_out\":\n        denom = fan_out\n    elif mode == \"fan_avg\":\n        denom = (fan_in + fan_out) / 2\n\n    variance = scale / denom\n\n    if distribution == \"truncated_normal\":\n        # constant is stddev of standard normal truncated to (-2, 2)\n        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)\n    elif distribution == \"normal\":\n        with torch.no_grad():\n            tensor.normal_(std=math.sqrt(variance))\n    elif distribution == \"uniform\":\n        bound = math.sqrt(3 * variance)\n        with torch.no_grad():\n            tensor.uniform_(-bound, bound)\n    else:\n        raise ValueError(f\"invalid distribution {distribution}\")\n\n\ndef lecun_normal_(tensor):\n    variance_scaling_(tensor, mode=\"fan_in\", distribution=\"truncated_normal\")\n\n\ndef default_flax_embed_init(tensor):\n    variance_scaling_(tensor, mode=\"fan_in\", distribution=\"normal\")\n\n\n@dataclass\n@auto_docstring(\n    custom_intro=\"\"\"\n    Base class for vision model's outputs that also 
contains image embeddings of the pooling of the last hidden states.\n    \"\"\"\n)\n# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip\nclass SiglipVisionModelOutput(ModelOutput):\n    r\"\"\"\n    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):\n        The image embeddings obtained by applying the projection layer to the pooler_output.\n    \"\"\"\n\n    image_embeds: Optional[torch.FloatTensor] = None\n    last_hidden_state: Optional[torch.FloatTensor] = None\n    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None\n    attentions: Optional[tuple[torch.FloatTensor, ...]] = None\n\n\n@dataclass\n@auto_docstring(\n    custom_intro=\"\"\"\n    Base class for text model's outputs that also contains a pooling of the last hidden states.\n    \"\"\"\n)\n# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip\nclass SiglipTextModelOutput(ModelOutput):\n    r\"\"\"\n    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):\n        The text embeddings obtained by applying the projection layer to the pooler_output.\n    \"\"\"\n\n    text_embeds: Optional[torch.FloatTensor] = None\n    last_hidden_state: Optional[torch.FloatTensor] = None\n    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None\n    attentions: Optional[tuple[torch.FloatTensor, ...]] = None\n\n\n@dataclass\n@auto_docstring\n# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip\nclass SiglipOutput(ModelOutput):\n    r\"\"\"\n    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):\n        Contrastive loss for image-text similarity.\n    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):\n        The scaled dot product 
scores between `image_embeds` and `text_embeds`. This represents the image-text\n        similarity scores.\n    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):\n        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image\n        similarity scores.\n    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):\n        The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].\n    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):\n        The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].\n    text_model_output (`BaseModelOutputWithPooling`):\n        The output of the [`SiglipTextModel`].\n    vision_model_output (`BaseModelOutputWithPooling`):\n        The output of the [`SiglipVisionModel`].\n    \"\"\"\n\n    loss: Optional[torch.FloatTensor] = None\n    logits_per_image: Optional[torch.FloatTensor] = None\n    logits_per_text: Optional[torch.FloatTensor] = None\n    text_embeds: Optional[torch.FloatTensor] = None\n    image_embeds: Optional[torch.FloatTensor] = None\n    text_model_output: BaseModelOutputWithPooling = None\n    vision_model_output: BaseModelOutputWithPooling = None\n\n    def to_tuple(self) -> tuple[Any]:\n        return tuple(\n            self[k] if k not in [\"text_model_output\", \"vision_model_output\"] else getattr(self, k).to_tuple()\n            for k in self.keys()\n        )\n\n\nclass SiglipVisionEmbeddings(nn.Module):\n    def __init__(self, config: SiglipVisionConfig):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        self.patch_embedding = nn.Conv2d(\n            in_channels=config.num_channels,\n            out_channels=self.embed_dim,\n            
kernel_size=self.patch_size,\n            stride=self.patch_size,\n            padding=\"valid\",\n        )\n\n        self.num_patches = (self.image_size // self.patch_size) ** 2\n        self.num_positions = self.num_patches\n        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)\n        self.register_buffer(\"position_ids\", torch.arange(self.num_positions).expand((1, -1)), persistent=False)\n\n    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:\n        \"\"\"\n        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution\n        images. This method is also adapted to support torch.jit tracing and no class embeddings.\n\n        Adapted from:\n        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and\n        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211\n        \"\"\"\n\n        num_patches = embeddings.shape[1]\n        num_positions = self.position_embedding.weight.shape[0]\n\n        # always interpolate when tracing to ensure the exported model works for dynamic input shapes\n        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:\n            return self.position_embedding(self.position_ids)\n\n        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)\n\n        dim = embeddings.shape[-1]\n\n        new_height = height // self.patch_size\n        new_width = width // self.patch_size\n\n        sqrt_num_positions = torch_int(num_positions**0.5)\n        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)\n        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)\n\n        patch_pos_embed = nn.functional.interpolate(\n            patch_pos_embed,\n        
    size=(new_height, new_width),\n            mode=\"bicubic\",\n            align_corners=False,\n        )\n\n        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n        return patch_pos_embed\n\n    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:\n        _, _, height, width = pixel_values.shape\n        target_dtype = self.patch_embedding.weight.dtype\n        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]\n        embeddings = patch_embeds.flatten(2).transpose(1, 2)\n\n        if interpolate_pos_encoding:\n            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)\n        else:\n            embeddings = embeddings + self.position_embedding(self.position_ids)\n        return embeddings\n\n\n# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip\nclass SiglipTextEmbeddings(nn.Module):\n    def __init__(self, config: SiglipTextConfig):\n        super().__init__()\n        embed_dim = config.hidden_size\n\n        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)\n        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)\n\n        # position_ids (1, len position emb) is contiguous in memory and exported when serialized\n        self.register_buffer(\n            \"position_ids\", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False\n        )\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ) -> torch.Tensor:\n        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]\n        max_position_embedding = self.position_embedding.weight.shape[0]\n\n        if seq_length > 
max_position_embedding:\n            raise ValueError(\n                f\"Sequence length must be less than max_position_embeddings (got `sequence length`: \"\n                f\"{seq_length} and max_position_embeddings: {max_position_embedding}\"\n            )\n\n        if position_ids is None:\n            position_ids = self.position_ids[:, :seq_length]\n\n        if inputs_embeds is None:\n            inputs_embeds = self.token_embedding(input_ids)\n\n        position_embeddings = self.position_embedding(position_ids)\n        embeddings = inputs_embeds + position_embeddings\n\n        return embeddings\n\n\ndef eager_attention_forward(\n    module: nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    scaling: float,\n    dropout: float = 0.0,\n    **kwargs,\n):\n    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling\n    if attention_mask is not None:\n        attn_weights = attn_weights + attention_mask\n\n    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)\n    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)\n\n    attn_output = torch.matmul(attn_weights, value)\n    attn_output = attn_output.transpose(1, 2).contiguous()\n\n    return attn_output, attn_weights\n\n\nclass SiglipAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.embed_dim // self.num_heads\n        if self.head_dim * self.num_heads != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:\"\n                f\" {self.num_heads}).\"\n    
        )\n        self.scale = self.head_dim**-0.5\n        self.dropout = config.attention_dropout\n        self.is_causal = False\n\n        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)\n        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)\n        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)\n        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = False,\n    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n\n        batch_size, seq_length, embed_dim = hidden_states.shape\n\n        queries = self.q_proj(hidden_states)\n        keys = self.k_proj(hidden_states)\n        values = self.v_proj(hidden_states)\n\n        queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)\n        keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)\n        values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)\n\n        attention_interface: Callable = eager_attention_forward\n        if self.config._attn_implementation != \"eager\":\n            if self.config._attn_implementation == \"sdpa\" and output_attentions:\n                logger.warning_once(\n                    \"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to \"\n                    'eager attention. 
This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n                )\n            else:\n                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n        attn_output, attn_weights = attention_interface(\n            self,\n            queries,\n            keys,\n            values,\n            attention_mask,\n            is_causal=self.is_causal,\n            scaling=self.scale,\n            dropout=0.0 if not self.training else self.dropout,\n        )\n\n        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()\n        attn_output = self.out_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights\n\n\n# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip\nclass SiglipMLP(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)\n        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass SiglipEncoderLayer(GradientCheckpointingLayer):\n    def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)\n        self.self_attn = SiglipAttention(config)\n        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)\n        self.mlp = SiglipMLP(config)\n\n    def forward(\n        self,\n        
hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n        output_attentions: Optional[bool] = False,\n    ) -> tuple[torch.FloatTensor]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`):\n                Input to the layer of shape `(batch, seq_len, embed_dim)`.\n            attention_mask (`torch.FloatTensor`):\n                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*, defaults to `False`):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n        \"\"\"\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states, attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            output_attentions=output_attentions,\n        )\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (attn_weights,)\n\n        return outputs\n\n\n@auto_docstring\nclass SiglipPreTrainedModel(PreTrainedModel):\n    config_class = SiglipConfig\n    base_model_prefix = \"siglip\"\n    supports_gradient_checkpointing = True\n\n    _no_split_modules = [\n        \"SiglipTextEmbeddings\",\n        \"SiglipEncoderLayer\",\n        \"SiglipVisionEmbeddings\",\n        \"SiglipEncoderLayer\",\n        \"SiglipMultiheadAttentionPoolingHead\",\n    ]\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_flex_attn = True\n    _supports_attention_backend = True\n\n    def 
_init_weights(self, module):\n        \"\"\"Initialize the weights\"\"\"\n        if isinstance(module, SiglipVisionEmbeddings):\n            width = (\n                self.config.vision_config.hidden_size\n                if isinstance(self.config, SiglipConfig)\n                else self.config.hidden_size\n            )\n            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))\n        elif isinstance(module, nn.Embedding):\n            default_flax_embed_init(module.weight)\n        elif isinstance(module, SiglipAttention):\n            nn.init.xavier_uniform_(module.q_proj.weight)\n            nn.init.xavier_uniform_(module.k_proj.weight)\n            nn.init.xavier_uniform_(module.v_proj.weight)\n            nn.init.xavier_uniform_(module.out_proj.weight)\n            nn.init.zeros_(module.q_proj.bias)\n            nn.init.zeros_(module.k_proj.bias)\n            nn.init.zeros_(module.v_proj.bias)\n            nn.init.zeros_(module.out_proj.bias)\n        elif isinstance(module, SiglipMLP):\n            nn.init.xavier_uniform_(module.fc1.weight)\n            nn.init.xavier_uniform_(module.fc2.weight)\n            nn.init.normal_(module.fc1.bias, std=1e-6)\n            nn.init.normal_(module.fc2.bias, std=1e-6)\n        elif isinstance(module, SiglipMultiheadAttentionPoolingHead):\n            nn.init.xavier_uniform_(module.probe.data)\n            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)\n            nn.init.zeros_(module.attention.in_proj_bias.data)\n        elif isinstance(module, SiglipModel):\n            logit_scale_init = torch.log(torch.tensor(1.0))\n            module.logit_scale.data.fill_(logit_scale_init)\n            module.logit_bias.data.zero_()\n        elif isinstance(module, SiglipForImageClassification):\n            nn.init.normal_(\n                module.classifier.weight,\n                std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,\n            )\n    
    elif isinstance(module, (nn.Linear, nn.Conv2d)):\n            lecun_normal_(module.weight)\n            if module.bias is not None:\n                nn.init.zeros_(module.bias)\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n\n\n# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip\nclass SiglipEncoder(nn.Module):\n    \"\"\"\n    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a\n    [`SiglipEncoderLayer`].\n\n    Args:\n        config: SiglipConfig\n    \"\"\"\n\n    def __init__(self, config: SiglipConfig):\n        super().__init__()\n        self.config = config\n        self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])\n        self.gradient_checkpointing = False\n\n    # Ignore copy\n    @can_return_tuple\n    def forward(\n        self,\n        inputs_embeds,\n        attention_mask: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n    ) -> BaseModelOutput:\n        r\"\"\"\n        Args:\n            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):\n                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.\n                This is useful if you want more control over how to convert `input_ids` indices into associated vectors\n                than the model's internal embedding lookup matrix.\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n                [What are attention masks?](../glossary#attention-mask)\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_hidden_states (`bool`, *optional*):\n                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors\n                for more detail.\n            return_dict (`bool`, *optional*):\n                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        \"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n\n        encoder_states = () if output_hidden_states else None\n        all_attentions = () if output_attentions else None\n\n        hidden_states = inputs_embeds\n        for encoder_layer in self.layers:\n            if output_hidden_states:\n                encoder_states = encoder_states + (hidden_states,)\n\n            layer_outputs = encoder_layer(\n                hidden_states,\n                attention_mask,\n                output_attentions=output_attentions,\n            )\n\n            hidden_states = layer_outputs[0]\n\n            if output_attentions:\n                all_attentions = all_attentions + (layer_outputs[1],)\n\n        if output_hidden_states:\n            encoder_states = encoder_states + (hidden_states,)\n\n        return BaseModelOutput(\n            last_hidden_state=hidden_states,\n            hidden_states=encoder_states,\n            attentions=all_attentions,\n        
)\n\n\nclass SiglipTextTransformer(nn.Module):\n    def __init__(self, config: SiglipTextConfig):\n        super().__init__()\n        self.config = config\n        embed_dim = config.hidden_size\n        self.embeddings = SiglipTextEmbeddings(config)\n        self.encoder = SiglipEncoder(config)\n        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)\n\n        self.head = nn.Linear(embed_dim, config.projection_size)\n        self._use_flash_attention_2 = config._attn_implementation == \"flash_attention_2\"\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n    ) -> BaseModelOutputWithPooling:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n\n        if input_ids is None:\n            raise ValueError(\"You have to specify input_ids\")\n\n        input_shape = input_ids.size()\n        input_ids = input_ids.view(-1, input_shape[-1])\n\n        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)\n\n        # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.\n        # expand attention_mask\n        if attention_mask is not None and not self._use_flash_attention_2:\n            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]\n            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)\n\n        encoder_outputs: BaseModelOutput = self.encoder(\n            inputs_embeds=hidden_states,\n            
attention_mask=attention_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n        )\n\n        last_hidden_state = encoder_outputs.last_hidden_state\n        last_hidden_state = self.final_layer_norm(last_hidden_state)\n\n        # Assuming \"sticky\" EOS tokenization, last token is always EOS.\n        pooled_output = last_hidden_state[:, -1, :]\n        pooled_output = self.head(pooled_output)\n\n        return BaseModelOutputWithPooling(\n            last_hidden_state=last_hidden_state,\n            pooler_output=pooled_output,\n            hidden_states=encoder_outputs.hidden_states,\n            attentions=encoder_outputs.attentions,\n        )\n\n\n@auto_docstring(\n    custom_intro=\"\"\"\n    The text model from SigLIP without any head or projection on top.\n    \"\"\"\n)\nclass SiglipTextModel(SiglipPreTrainedModel):\n    config_class = SiglipTextConfig\n\n    def __init__(self, config: SiglipTextConfig):\n        super().__init__(config)\n        self.text_model = SiglipTextTransformer(config)\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self) -> nn.Module:\n        return self.text_model.embeddings.token_embedding\n\n    def set_input_embeddings(self, value):\n        self.text_model.embeddings.token_embedding = value\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n    ) -> BaseModelOutputWithPooling:\n        r\"\"\"\n        Examples:\n\n        ```python\n        >>> from transformers import AutoTokenizer, SiglipTextModel\n\n        >>> model = SiglipTextModel.from_pretrained(\"google/siglip-base-patch16-224\")\n        >>> 
tokenizer = AutoTokenizer.from_pretrained(\"google/siglip-base-patch16-224\")\n\n        >>> # important: make sure to set padding=\"max_length\" as that's how the model was trained\n        >>> inputs = tokenizer([\"a photo of a cat\", \"a photo of a dog\"], padding=\"max_length\", return_tensors=\"pt\")\n\n        >>> outputs = model(**inputs)\n        >>> last_hidden_state = outputs.last_hidden_state\n        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states\n        ```\"\"\"\n\n        return self.text_model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n        )\n\n\nclass SiglipVisionTransformer(nn.Module):\n    def __init__(self, config: SiglipVisionConfig):\n        super().__init__()\n        self.config = config\n        embed_dim = config.hidden_size\n\n        self.embeddings = SiglipVisionEmbeddings(config)\n        self.encoder = SiglipEncoder(config)\n        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)\n        self.use_head = True if not hasattr(config, \"vision_use_head\") else config.vision_use_head\n        if self.use_head:\n            self.head = SiglipMultiheadAttentionPoolingHead(config)\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        pixel_values,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        interpolate_pos_encoding: Optional[bool] = False,\n    ) -> BaseModelOutputWithPooling:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n\n        hidden_states = self.embeddings(pixel_values, 
interpolate_pos_encoding=interpolate_pos_encoding)\n        # Convert to bfloat16 if the encoder uses bfloat16\n        if len(self.encoder.layers) > 0 and self.encoder.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:\n            hidden_states = hidden_states.to(torch.bfloat16)\n\n        encoder_outputs: BaseModelOutput = self.encoder(\n            inputs_embeds=hidden_states,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n        )\n\n        last_hidden_state = encoder_outputs.last_hidden_state\n        last_hidden_state = self.post_layernorm(last_hidden_state)\n\n        pooler_output = self.head(last_hidden_state) if self.use_head else None\n\n        return BaseModelOutputWithPooling(\n            last_hidden_state=last_hidden_state,\n            pooler_output=pooler_output,\n            hidden_states=encoder_outputs.hidden_states,\n            attentions=encoder_outputs.attentions,\n        )\n\n\nclass SiglipMultiheadAttentionPoolingHead(nn.Module):\n    \"\"\"Multihead Attention Pooling.\"\"\"\n\n    def __init__(self, config: SiglipVisionConfig):\n        super().__init__()\n\n        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))\n        self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)\n        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.mlp = SiglipMLP(config)\n\n    def forward(self, hidden_state):\n        batch_size = hidden_state.shape[0]\n        probe = self.probe.repeat(batch_size, 1, 1)\n\n        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]\n\n        residual = hidden_state\n        hidden_state = self.layernorm(hidden_state)\n        hidden_state = residual + self.mlp(hidden_state)\n\n        return hidden_state[:, 0]\n\n\n@auto_docstring(\n    custom_intro=\"\"\"\n    The vision model from SigLIP without any head or 
projection on top.\n    \"\"\"\n)\nclass SiglipVisionModel(SiglipPreTrainedModel):\n    config_class = SiglipVisionConfig\n    main_input_name = \"pixel_values\"\n\n    def __init__(self, config: SiglipVisionConfig):\n        super().__init__(config)\n\n        self.vision_model = SiglipVisionTransformer(config)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self) -> nn.Module:\n        return self.vision_model.embeddings.patch_embedding\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        pixel_values,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        interpolate_pos_encoding: bool = False,\n    ) -> BaseModelOutputWithPooling:\n        r\"\"\"\n        Examples:\n\n        ```python\n        >>> from PIL import Image\n        >>> import requests\n        >>> from transformers import AutoProcessor, SiglipVisionModel\n\n        >>> model = SiglipVisionModel.from_pretrained(\"google/siglip-base-patch16-224\")\n        >>> processor = AutoProcessor.from_pretrained(\"google/siglip-base-patch16-224\")\n\n        >>> url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n        >>> image = Image.open(requests.get(url, stream=True).raw)\n\n        >>> inputs = processor(images=image, return_tensors=\"pt\")\n\n        >>> outputs = model(**inputs)\n        >>> last_hidden_state = outputs.last_hidden_state\n        >>> pooled_output = outputs.pooler_output  # pooled features\n        ```\"\"\"\n\n        return self.vision_model(\n            pixel_values=pixel_values,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            interpolate_pos_encoding=interpolate_pos_encoding,\n        )\n\n\n@auto_docstring\nclass SiglipModel(SiglipPreTrainedModel):\n    config_class = SiglipConfig\n\n    def __init__(self, config: SiglipConfig):\n     
   super().__init__(config)\n\n        if not isinstance(config.text_config, SiglipTextConfig):\n            raise TypeError(\n                \"config.text_config is expected to be of type SiglipTextConfig but is of type\"\n                f\" {type(config.text_config)}.\"\n            )\n\n        if not isinstance(config.vision_config, SiglipVisionConfig):\n            raise TypeError(\n                \"config.vision_config is expected to be of type SiglipVisionConfig but is of type\"\n                f\" {type(config.vision_config)}.\"\n            )\n\n        text_config = config.text_config\n        vision_config = config.vision_config\n\n        # First, initialize the text and vision models with proper attention implementation\n        text_model = SiglipTextModel._from_config(text_config)\n        vision_model = SiglipVisionModel._from_config(vision_config)\n\n        # Second, get the text and vision submodules (for backward compatibility)\n        self.text_model = text_model.text_model\n        self.vision_model = vision_model.vision_model\n\n        self.logit_scale = nn.Parameter(torch.randn(1))\n        self.logit_bias = nn.Parameter(torch.randn(1))\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    @auto_docstring\n    def get_text_features(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n    ) -> torch.FloatTensor:\n        r\"\"\"\n        Returns:\n            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by\n            applying the projection layer to the pooled output of [`SiglipTextModel`].\n\n        Examples:\n\n        ```python\n        >>> from transformers import AutoTokenizer, AutoModel\n        >>> import 
torch\n\n        >>> model = AutoModel.from_pretrained(\"google/siglip-base-patch16-224\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"google/siglip-base-patch16-224\")\n\n        >>> # important: make sure to set padding=\"max_length\" as that's how the model was trained\n        >>> inputs = tokenizer([\"a photo of a cat\", \"a photo of a dog\"], padding=\"max_length\", return_tensors=\"pt\")\n        >>> with torch.no_grad():\n        ...     text_features = model.get_text_features(**inputs)\n        ```\"\"\"\n        # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n\n        text_outputs: BaseModelOutputWithPooling = self.text_model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n        )\n\n        pooled_output = text_outputs.pooler_output\n\n        return pooled_output\n\n    @auto_docstring\n    def get_image_features(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        interpolate_pos_encoding: bool = False,\n    ) -> torch.FloatTensor:\n        r\"\"\"\n        Returns:\n            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by\n            applying the projection layer to the pooled output of [`SiglipVisionModel`].\n\n        Examples:\n\n        ```python\n        >>> from PIL import Image\n        >>> import requests\n        >>> from transformers 
import AutoProcessor, AutoModel\n        >>> import torch\n\n        >>> model = AutoModel.from_pretrained(\"google/siglip-base-patch16-224\")\n        >>> processor = AutoProcessor.from_pretrained(\"google/siglip-base-patch16-224\")\n\n        >>> url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n        >>> image = Image.open(requests.get(url, stream=True).raw)\n\n        >>> inputs = processor(images=image, return_tensors=\"pt\")\n\n        >>> with torch.no_grad():\n        ...     image_features = model.get_image_features(**inputs)\n        ```\"\"\"\n        # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n\n        vision_outputs: BaseModelOutputWithPooling = self.vision_model(\n            pixel_values=pixel_values,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            interpolate_pos_encoding=interpolate_pos_encoding,\n        )\n\n        pooled_output = vision_outputs.pooler_output\n\n        return pooled_output\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        pixel_values: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        return_loss: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        interpolate_pos_encoding: bool = False,\n    ) -> SiglipOutput:\n        r\"\"\"\n        return_loss (`bool`, *optional*):\n            Whether or not to return the contrastive loss.\n\n    
    Examples:\n\n        ```python\n        >>> from PIL import Image\n        >>> import requests\n        >>> from transformers import AutoProcessor, AutoModel\n        >>> import torch\n\n        >>> model = AutoModel.from_pretrained(\"google/siglip-base-patch16-224\")\n        >>> processor = AutoProcessor.from_pretrained(\"google/siglip-base-patch16-224\")\n\n        >>> url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n        >>> image = Image.open(requests.get(url, stream=True).raw)\n\n        >>> texts = [\"a photo of 2 cats\", \"a photo of 2 dogs\"]\n        >>> # important: we pass `padding=max_length` since the model was trained with this\n        >>> inputs = processor(text=texts, images=image, padding=\"max_length\", return_tensors=\"pt\")\n\n        >>> with torch.no_grad():\n        ...     outputs = model(**inputs)\n\n        >>> logits_per_image = outputs.logits_per_image\n        >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities\n        >>> print(f\"{probs[0][0]:.1%} that image 0 is '{texts[0]}'\")\n        31.9% that image 0 is 'a photo of 2 cats'\n        ```\"\"\"\n        # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n\n        vision_outputs: BaseModelOutputWithPooling = self.vision_model(\n            pixel_values=pixel_values,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            interpolate_pos_encoding=interpolate_pos_encoding,\n        )\n\n        text_outputs: BaseModelOutputWithPooling = self.text_model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            
position_ids=position_ids,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n        )\n\n        image_embeds = vision_outputs.pooler_output\n        text_embeds = text_outputs.pooler_output\n\n        # normalized features\n        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)\n        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)\n\n        # cosine similarity as logits\n        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))\n\n        logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)\n        logits_per_text = logits_per_text * logit_scale.exp() + logit_bias\n\n        logits_per_image = logits_per_text.t()\n\n        loss = None\n        if return_loss:\n            # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287\n            eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)\n            m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye\n            loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)\n            nll = -torch.sum(loglik, dim=-1)\n            loss = nll.mean()\n\n        return SiglipOutput(\n            loss=loss,\n            logits_per_image=logits_per_image,\n            logits_per_text=logits_per_text,\n            text_embeds=text_embeds,\n            image_embeds=image_embeds,\n            text_model_output=text_outputs,\n            vision_model_output=vision_outputs,\n        )\n\n\n@auto_docstring(\n    custom_intro=\"\"\"\n    SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of\n    the patch tokens) e.g. 
for ImageNet.\n    \"\"\"\n)\nclass SiglipForImageClassification(SiglipPreTrainedModel):\n    main_input_name = \"pixel_values\"\n\n    def __init__(self, config: SiglipConfig) -> None:\n        super().__init__(config)\n\n        self.num_labels = config.num_labels\n\n        # Create the vision model with proper attention\n        # and take only vision_model submodule (for backward compatibility)\n        vision_model = SiglipVisionModel._from_config(config.vision_config)\n        self.vision_model = vision_model.vision_model\n\n        # Classifier head\n        self.classifier = (\n            nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()\n        )\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        pixel_values: Optional[torch.Tensor] = None,\n        labels: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        interpolate_pos_encoding: bool = False,\n    ) -> ImageClassifierOutput:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n\n        Examples:\n\n        ```python\n        >>> from transformers import AutoImageProcessor, SiglipForImageClassification\n        >>> import torch\n        >>> from PIL import Image\n        >>> import requests\n\n        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT\n        >>> url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n        >>> image = Image.open(requests.get(url, stream=True).raw)\n\n        >>> # note: we are loading a `SiglipModel` from the hub here,\n        >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.\n        >>> image_processor = AutoImageProcessor.from_pretrained(\"google/siglip-base-patch16-224\")\n        >>> model = SiglipForImageClassification.from_pretrained(\"google/siglip-base-patch16-224\")\n\n        >>> inputs = image_processor(images=image, return_tensors=\"pt\")\n        >>> outputs = model(**inputs)\n        >>> logits = outputs.logits\n        >>> # model predicts one of the two classes\n        >>> predicted_class_idx = logits.argmax(-1).item()\n        >>> print(\"Predicted class:\", model.config.id2label[predicted_class_idx])\n        Predicted class: LABEL_1\n        ```\"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n\n        outputs: BaseModelOutputWithPooling = self.vision_model(\n            pixel_values,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            interpolate_pos_encoding=interpolate_pos_encoding,\n        )\n\n        sequence_output = 
outputs.last_hidden_state\n\n        # average pool the patch tokens\n        sequence_output = torch.mean(sequence_output, dim=1)\n        # apply classifier\n        logits = self.classifier(sequence_output)\n\n        loss = None\n        if labels is not None:\n            # move labels to correct device to enable model parallelism\n            labels = labels.to(logits.device)\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type = \"regression\"\n                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(logits, labels)\n\n        return ImageClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n__all__ = [\n    \"SiglipModel\",\n    \"SiglipPreTrainedModel\",\n    \"SiglipTextModel\",\n    \"SiglipVisionModel\",\n    \"SiglipForImageClassification\",\n]"
  },
  {
    "path": "src/openpi/policies/aloha_policy.py",
    "content": "import dataclasses\nfrom typing import ClassVar\n\nimport einops\nimport numpy as np\n\nfrom openpi import transforms\n\n\ndef make_aloha_example() -> dict:\n    \"\"\"Creates a random input example for the Aloha policy.\"\"\"\n    return {\n        \"state\": np.ones((14,)),\n        \"images\": {\n            \"cam_high\": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),\n            \"cam_low\": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),\n            \"cam_left_wrist\": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),\n            \"cam_right_wrist\": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),\n        },\n        \"prompt\": \"do something\",\n    }\n\n\n@dataclasses.dataclass(frozen=True)\nclass AlohaInputs(transforms.DataTransformFn):\n    \"\"\"Inputs for the Aloha policy.\n\n    Expected inputs:\n    - images: dict[name, img] where img is [channel, height, width]. name must be in EXPECTED_CAMERAS.\n    - state: [14]\n    - actions: [action_horizon, 14]\n    \"\"\"\n\n    # If true, this will convert the joint and gripper values from the standard Aloha space to\n    # the space used by the pi internal runtime which was used to train the base model.\n    adapt_to_pi: bool = True\n\n    # The expected cameras names. All input cameras must be in this set. 
Missing cameras will be\n    # replaced with black images and the corresponding `image_mask` will be set to False.\n    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = (\"cam_high\", \"cam_low\", \"cam_left_wrist\", \"cam_right_wrist\")\n\n    def __call__(self, data: dict) -> dict:\n        data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)\n\n        in_images = data[\"images\"]\n        if set(in_images) - set(self.EXPECTED_CAMERAS):\n            raise ValueError(f\"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}\")\n\n        # Assume that base image always exists.\n        base_image = in_images[\"cam_high\"]\n\n        images = {\n            \"base_0_rgb\": base_image,\n        }\n        image_masks = {\n            \"base_0_rgb\": np.True_,\n        }\n\n        # Add the extra images.\n        extra_image_names = {\n            \"left_wrist_0_rgb\": \"cam_left_wrist\",\n            \"right_wrist_0_rgb\": \"cam_right_wrist\",\n        }\n        for dest, source in extra_image_names.items():\n            if source in in_images:\n                images[dest] = in_images[source]\n                image_masks[dest] = np.True_\n            else:\n                images[dest] = np.zeros_like(base_image)\n                image_masks[dest] = np.False_\n\n        inputs = {\n            \"image\": images,\n            \"image_mask\": image_masks,\n            \"state\": data[\"state\"],\n        }\n\n        # Actions are only available during training.\n        if \"actions\" in data:\n            actions = np.asarray(data[\"actions\"])\n            actions = _encode_actions_inv(actions, adapt_to_pi=self.adapt_to_pi)\n            inputs[\"actions\"] = actions\n\n        if \"prompt\" in data:\n            inputs[\"prompt\"] = data[\"prompt\"]\n\n        return inputs\n\n\n@dataclasses.dataclass(frozen=True)\nclass AlohaOutputs(transforms.DataTransformFn):\n    \"\"\"Outputs for the Aloha policy.\"\"\"\n\n    # If true, this will 
convert the joint and gripper values from the standard Aloha space to\n    # the space used by the pi internal runtime which was used to train the base model.\n    adapt_to_pi: bool = True\n\n    def __call__(self, data: dict) -> dict:\n        # Only return the first 14 dims.\n        actions = np.asarray(data[\"actions\"][:, :14])\n        return {\"actions\": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)}\n\n\ndef _joint_flip_mask() -> np.ndarray:\n    \"\"\"Used to convert between aloha and pi joint angles.\"\"\"\n    return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])\n\n\ndef _normalize(x, min_val, max_val):\n    return (x - min_val) / (max_val - min_val)\n\n\ndef _unnormalize(x, min_val, max_val):\n    return x * (max_val - min_val) + min_val\n\n\ndef _gripper_to_angular(value):\n    # Aloha transforms the gripper positions into a linear space. The following code\n    # reverses this transformation to be consistent with pi0 which is pretrained in\n    # angular space.\n    #\n    # These values are coming from the Aloha code:\n    # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED\n    value = _unnormalize(value, min_val=0.01844, max_val=0.05800)\n\n    # This is the inverse of the angular to linear transformation inside the Interbotix code.\n    def linear_to_radian(linear_position, arm_length, horn_radius):\n        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)\n        return np.arcsin(np.clip(value, -1.0, 1.0))\n\n    # The constants are taken from the Interbotix code.\n    value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)\n\n    # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110).\n    # There are 4096 total encoder counts and aloha uses a zero of 2048.\n    # Converting this to radians means that the normalized inputs are between (0.5476, 1.6296)\n    return _normalize(value, min_val=0.5476, max_val=1.6296)\n\n\ndef 
_gripper_from_angular(value):\n    # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.\n    # Note that the units are still angular but the range is different.\n\n    # We do not scale the output since the trossen model predictions are already in radians.\n    # See the comment in _gripper_to_angular for a derivation of the constant\n    value = value + 0.5476\n\n    # These values are coming from the Aloha code:\n    # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE\n    return _normalize(value, min_val=-0.6213, max_val=1.4910)\n\n\ndef _gripper_from_angular_inv(value):\n    # Directly inverts the gripper_from_angular function.\n    value = _unnormalize(value, min_val=-0.6213, max_val=1.4910)\n    return value - 0.5476\n\n\ndef _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:\n    # state is [left_arm_joint_angles, left_arm_gripper, right_arm_joint_angles, right_arm_gripper]\n    # dim sizes: [6, 1, 6, 1]\n    state = np.asarray(data[\"state\"])\n    state = _decode_state(state, adapt_to_pi=adapt_to_pi)\n\n    def convert_image(img):\n        img = np.asarray(img)\n        # Convert to uint8 if using float images.\n        if np.issubdtype(img.dtype, np.floating):\n            img = (255 * img).astype(np.uint8)\n        # Convert from [channel, height, width] to [height, width, channel].\n        return einops.rearrange(img, \"c h w -> h w c\")\n\n    images = data[\"images\"]\n    images_dict = {name: convert_image(img) for name, img in images.items()}\n\n    data[\"images\"] = images_dict\n    data[\"state\"] = state\n    return data\n\n\ndef _decode_state(state: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:\n    if adapt_to_pi:\n        # Flip the joints.\n        state = _joint_flip_mask() * state\n        # Reverse the gripper transformation that is being applied by the Aloha runtime.\n        state[[6, 13]] = _gripper_to_angular(state[[6, 13]])\n    return state\n\n\ndef 
_encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:\n    if adapt_to_pi:\n        # Flip the joints.\n        actions = _joint_flip_mask() * actions\n        actions[:, [6, 13]] = _gripper_from_angular(actions[:, [6, 13]])\n    return actions\n\n\ndef _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:\n    if adapt_to_pi:\n        actions = _joint_flip_mask() * actions\n        actions[:, [6, 13]] = _gripper_from_angular_inv(actions[:, [6, 13]])\n    return actions\n"
  },
  {
    "path": "src/openpi/policies/droid_policy.py",
    "content": "import dataclasses\n\nimport einops\nimport numpy as np\n\nfrom openpi import transforms\nfrom openpi.models import model as _model\n\n\ndef make_droid_example() -> dict:\n    \"\"\"Creates a random input example for the Droid policy.\"\"\"\n    return {\n        \"observation/exterior_image_1_left\": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),\n        \"observation/wrist_image_left\": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),\n        \"observation/joint_position\": np.random.rand(7),\n        \"observation/gripper_position\": np.random.rand(1),\n        \"prompt\": \"do something\",\n    }\n\n\ndef _parse_image(image) -> np.ndarray:\n    image = np.asarray(image)\n    if np.issubdtype(image.dtype, np.floating):\n        image = (255 * image).astype(np.uint8)\n    if image.shape[0] == 3:\n        image = einops.rearrange(image, \"c h w -> h w c\")\n    return image\n\n\n@dataclasses.dataclass(frozen=True)\nclass DroidInputs(transforms.DataTransformFn):\n    # Determines which model will be used.\n    model_type: _model.ModelType\n\n    def __call__(self, data: dict) -> dict:\n        gripper_pos = np.asarray(data[\"observation/gripper_position\"])\n        if gripper_pos.ndim == 0:\n            # Ensure gripper position is a 1D array, not a scalar, so we can concatenate with joint positions\n            gripper_pos = gripper_pos[np.newaxis]\n        state = np.concatenate([data[\"observation/joint_position\"], gripper_pos])\n\n        # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically\n        # stores as float32 (C,H,W), gets skipped for policy inference\n        base_image = _parse_image(data[\"observation/exterior_image_1_left\"])\n        wrist_image = _parse_image(data[\"observation/wrist_image_left\"])\n\n        match self.model_type:\n            case _model.ModelType.PI0 | _model.ModelType.PI05:\n                names = (\"base_0_rgb\", \"left_wrist_0_rgb\", 
\"right_wrist_0_rgb\")\n                images = (base_image, wrist_image, np.zeros_like(base_image))\n                image_masks = (np.True_, np.True_, np.False_)\n            case _model.ModelType.PI0_FAST:\n                names = (\"base_0_rgb\", \"base_1_rgb\", \"wrist_0_rgb\")\n                # We don't mask out padding images for FAST models.\n                images = (base_image, np.zeros_like(base_image), wrist_image)\n                image_masks = (np.True_, np.True_, np.True_)\n            case _:\n                raise ValueError(f\"Unsupported model type: {self.model_type}\")\n\n        inputs = {\n            \"state\": state,\n            \"image\": dict(zip(names, images, strict=True)),\n            \"image_mask\": dict(zip(names, image_masks, strict=True)),\n        }\n\n        if \"actions\" in data:\n            inputs[\"actions\"] = np.asarray(data[\"actions\"])\n\n        if \"prompt\" in data:\n            if isinstance(data[\"prompt\"], bytes):\n                data[\"prompt\"] = data[\"prompt\"].decode(\"utf-8\")\n            inputs[\"prompt\"] = data[\"prompt\"]\n\n        return inputs\n\n\n@dataclasses.dataclass(frozen=True)\nclass DroidOutputs(transforms.DataTransformFn):\n    def __call__(self, data: dict) -> dict:\n        # Only return the first 8 dims.\n        return {\"actions\": np.asarray(data[\"actions\"][:, :8])}\n"
  },
  {
    "path": "src/openpi/policies/libero_policy.py",
    "content": "import dataclasses\n\nimport einops\nimport numpy as np\n\nfrom openpi import transforms\nfrom openpi.models import model as _model\n\n\ndef make_libero_example() -> dict:\n    \"\"\"Creates a random input example for the Libero policy.\"\"\"\n    return {\n        \"observation/state\": np.random.rand(8),\n        \"observation/image\": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),\n        \"observation/wrist_image\": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),\n        \"prompt\": \"do something\",\n    }\n\n\ndef _parse_image(image) -> np.ndarray:\n    image = np.asarray(image)\n    if np.issubdtype(image.dtype, np.floating):\n        image = (255 * image).astype(np.uint8)\n    if image.shape[0] == 3:\n        image = einops.rearrange(image, \"c h w -> h w c\")\n    return image\n\n\n@dataclasses.dataclass(frozen=True)\nclass LiberoInputs(transforms.DataTransformFn):\n    \"\"\"\n    This class is used to convert inputs to the model to the expected format. It is used for both training and inference.\n\n    For your own dataset, you can copy this class and modify the keys based on the comments below to pipe\n    the correct elements of your dataset into the model.\n    \"\"\"\n\n    # Determines which model will be used.\n    # Do not change this for your own dataset.\n    model_type: _model.ModelType\n\n    def __call__(self, data: dict) -> dict:\n        # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically\n        # stores as float32 (C,H,W), gets skipped for policy inference.\n        # Keep this for your own dataset, but if your dataset stores the images\n        # in a different key than \"observation/image\" or \"observation/wrist_image\",\n        # you should change it below.\n        # Pi0 models support three image inputs at the moment: one third-person view,\n        # and two wrist views (left and right). If your dataset does not have a particular type\n        # of image, e.g. wrist images, you can comment it out here and replace it with zeros like we do for the\n        # right wrist image below.\n        base_image = _parse_image(data[\"observation/image\"])\n        wrist_image = _parse_image(data[\"observation/wrist_image\"])\n\n        # Create inputs dict. Do not change the keys in the dict below.\n        inputs = {\n            \"state\": data[\"observation/state\"],\n            \"image\": {\n                \"base_0_rgb\": base_image,\n                \"left_wrist_0_rgb\": wrist_image,\n                # Pad any non-existent images with zero-arrays of the appropriate shape.\n                \"right_wrist_0_rgb\": np.zeros_like(base_image),\n            },\n            \"image_mask\": {\n                \"base_0_rgb\": np.True_,\n                \"left_wrist_0_rgb\": np.True_,\n                # We only mask padding images for pi0 model, not pi0-FAST. Do not change this for your own dataset.\n                \"right_wrist_0_rgb\": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,\n            },\n        }\n\n        # Pass the actions through unchanged; this transform does not pad them to the model action\n        # dimension. Keep this for your own dataset.\n        # Actions are only available during training.\n        if \"actions\" in data:\n            inputs[\"actions\"] = data[\"actions\"]\n\n        # Pass the prompt (aka language instruction) to the model.\n        # Keep this for your own dataset (but modify the key if the instruction is not\n        # stored in \"prompt\"; the output dict always needs to have the key \"prompt\").\n        if \"prompt\" in data:\n            inputs[\"prompt\"] = data[\"prompt\"]\n\n        return inputs\n\n\n@dataclasses.dataclass(frozen=True)\nclass LiberoOutputs(transforms.DataTransformFn):\n    \"\"\"\n    This class is used to convert outputs from the model back to the dataset-specific format. It is\n    used for inference only.\n\n    For your own dataset, you can copy this class and modify the action dimension based on the comments below.\n    \"\"\"\n\n    def __call__(self, data: dict) -> dict:\n        # Only return the first N action dimensions -- the model's action dimension can be larger than the\n        # dataset's (the extra dimensions are padding), so we slice out the dimensions the dataset uses.\n        # For Libero, we only return the first 7 action dimensions (since the rest is padding).\n        # For your own dataset, replace `7` with the action dimension of your dataset.\n        return {\"actions\": np.asarray(data[\"actions\"][:, :7])}\n"
  },
  {
    "path": "src/openpi/policies/policy.py",
    "content": "from collections.abc import Sequence\nimport logging\nimport pathlib\nimport time\nfrom typing import Any, TypeAlias\n\nimport flax\nimport flax.traverse_util\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom openpi_client import base_policy as _base_policy\nimport torch\nfrom typing_extensions import override\n\nfrom openpi import transforms as _transforms\nfrom openpi.models import model as _model\nfrom openpi.shared import array_typing as at\nfrom openpi.shared import nnx_utils\n\nBasePolicy: TypeAlias = _base_policy.BasePolicy\n\n\nclass Policy(BasePolicy):\n    def __init__(\n        self,\n        model: _model.BaseModel,\n        *,\n        rng: at.KeyArrayLike | None = None,\n        transforms: Sequence[_transforms.DataTransformFn] = (),\n        output_transforms: Sequence[_transforms.DataTransformFn] = (),\n        sample_kwargs: dict[str, Any] | None = None,\n        metadata: dict[str, Any] | None = None,\n        pytorch_device: str = \"cpu\",\n        is_pytorch: bool = False,\n    ):\n        \"\"\"Initialize the Policy.\n\n        Args:\n            model: The model to use for action sampling.\n            rng: Random number generator key for JAX models. Defaults to `jax.random.key(0)`. Ignored for\n                PyTorch models.\n            transforms: Input data transformations to apply before inference.\n            output_transforms: Output data transformations to apply after inference.\n            sample_kwargs: Additional keyword arguments to pass to model.sample_actions.\n            metadata: Additional metadata to store with the policy.\n            pytorch_device: Device to use for PyTorch models (e.g., \"cpu\", \"cuda:0\").\n                          Only relevant when is_pytorch=True.\n            is_pytorch: Whether the model is a PyTorch model. If False, assumes JAX model.\n        \"\"\"\n        self._model = model\n        self._input_transform = _transforms.compose(transforms)\n        self._output_transform = _transforms.compose(output_transforms)\n        self._sample_kwargs = sample_kwargs or {}\n        self._metadata = metadata or {}\n        self._is_pytorch_model = is_pytorch\n        self._pytorch_device = pytorch_device\n\n        if self._is_pytorch_model:\n            self._model = self._model.to(pytorch_device)\n            self._model.eval()\n            self._sample_actions = model.sample_actions\n        else:\n            # JAX model setup\n            self._sample_actions = nnx_utils.module_jit(model.sample_actions)\n            # Use an explicit None check: `rng or ...` would evaluate the truthiness of a JAX key array,\n            # which is ambiguous (and raises) for multi-element keys such as legacy uint32 keys.\n            self._rng = rng if rng is not None else jax.random.key(0)\n\n    @override\n    def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict:  # type: ignore[misc]\n        # Make a copy since transformations may modify the inputs in place.\n        inputs = jax.tree.map(lambda x: x, obs)\n        inputs = self._input_transform(inputs)\n        if not self._is_pytorch_model:\n            # Make a batch and convert to jax.Array.\n            inputs = jax.tree.map(lambda x: jnp.asarray(x)[np.newaxis, ...], inputs)\n            self._rng, sample_rng_or_pytorch_device = jax.random.split(self._rng)\n        else:\n            # Convert inputs to PyTorch tensors and move to correct device\n            inputs = jax.tree.map(lambda x: torch.from_numpy(np.array(x)).to(self._pytorch_device)[None, ...], inputs)\n            sample_rng_or_pytorch_device = self._pytorch_device\n\n        # Prepare kwargs for sample_actions\n        sample_kwargs = dict(self._sample_kwargs)\n        if noise is not None:\n            noise = torch.from_numpy(noise).to(self._pytorch_device) if self._is_pytorch_model else jnp.asarray(noise)\n\n            if noise.ndim == 2:  # If noise is (action_horizon, action_dim), add batch dimension\n                noise = noise[None, ...]  # Make it (1, action_horizon, action_dim)\n            sample_kwargs[\"noise\"] = noise\n\n        observation = _model.Observation.from_dict(inputs)\n        start_time = time.monotonic()\n        outputs = {\n            \"state\": inputs[\"state\"],\n            \"actions\": self._sample_actions(sample_rng_or_pytorch_device, observation, **sample_kwargs),\n        }\n        if self._is_pytorch_model:\n            outputs = jax.tree.map(lambda x: np.asarray(x[0, ...].detach().cpu()), outputs)\n        else:\n            outputs = jax.tree.map(lambda x: np.asarray(x[0, ...]), outputs)\n        # Stop the clock only after the results have been copied to host memory: JAX (and CUDA PyTorch)\n        # dispatch work asynchronously, so the sampling call itself may return before the computation is done.\n        model_time = time.monotonic() - start_time\n\n        outputs = self._output_transform(outputs)\n        outputs[\"policy_timing\"] = {\n            \"infer_ms\": model_time * 1000,\n        }\n        return outputs\n\n    @property\n    def metadata(self) -> dict[str, Any]:\n        return self._metadata\n\n\nclass PolicyRecorder(_base_policy.BasePolicy):\n    \"\"\"Records the policy's behavior to disk.\"\"\"\n\n    def __init__(self, policy: _base_policy.BasePolicy, record_dir: str):\n        self._policy = policy\n\n        logging.info(f\"Dumping policy records to: {record_dir}\")\n        self._record_dir = pathlib.Path(record_dir)\n        self._record_dir.mkdir(parents=True, exist_ok=True)\n        self._record_step = 0\n\n    @override\n    def infer(self, obs: dict) -> dict:  # type: ignore[misc]\n        results = self._policy.infer(obs)\n\n        data = {\"inputs\": obs, \"outputs\": results}\n        data = flax.traverse_util.flatten_dict(data, sep=\"/\")\n\n        output_path = self._record_dir / f\"step_{self._record_step}\"\n        self._record_step += 1\n\n        np.save(output_path, np.asarray(data))\n        return results\n"
  },
  {
    "path": "src/openpi/policies/policy_config.py",
    "content": "import logging\nimport os\nimport pathlib\nfrom typing import Any\n\nimport jax.numpy as jnp\n\nimport openpi.models.model as _model\nimport openpi.policies.policy as _policy\nimport openpi.shared.download as download\nfrom openpi.training import checkpoints as _checkpoints\nfrom openpi.training import config as _config\nimport openpi.transforms as transforms\n\n\ndef create_trained_policy(\n    train_config: _config.TrainConfig,\n    checkpoint_dir: pathlib.Path | str,\n    *,\n    repack_transforms: transforms.Group | None = None,\n    sample_kwargs: dict[str, Any] | None = None,\n    default_prompt: str | None = None,\n    norm_stats: dict[str, transforms.NormStats] | None = None,\n    pytorch_device: str | None = None,\n) -> _policy.Policy:\n    \"\"\"Create a policy from a trained checkpoint.\n\n    Args:\n        train_config: The training config to use to create the model.\n        checkpoint_dir: The directory to load the model from. May also be a remote URL supported by\n            `download.maybe_download`, in which case it is first downloaded to the local cache.\n        repack_transforms: Optional transforms that will be applied before any other transforms.\n        sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default\n            kwargs will be used.\n        default_prompt: The default prompt to use for the policy. Will inject the prompt into the input\n            data if it doesn't already exist.\n        norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded\n            from the checkpoint directory.\n        pytorch_device: Device to use for PyTorch models (e.g., \"cpu\", \"cuda\", \"cuda:0\").\n                      If None and the checkpoint is a PyTorch model, will use \"cuda\" if available,\n                      otherwise \"cpu\". Ignored for JAX models.\n\n    Note:\n        The function automatically detects whether the model is PyTorch-based by checking for the\n        presence of \"model.safetensors\" in the checkpoint directory.\n    \"\"\"\n    repack_transforms = repack_transforms or transforms.Group()\n    checkpoint_dir = download.maybe_download(str(checkpoint_dir))\n\n    # Check if this is a PyTorch model by looking for model.safetensors\n    weight_path = os.path.join(checkpoint_dir, \"model.safetensors\")\n    is_pytorch = os.path.exists(weight_path)\n\n    logging.info(\"Loading model...\")\n    if is_pytorch:\n        model = train_config.model.load_pytorch(train_config, weight_path)\n        model.paligemma_with_expert.to_bfloat16_for_selected_params(\"bfloat16\")\n    else:\n        model = train_config.model.load(_model.restore_params(checkpoint_dir / \"params\", dtype=jnp.bfloat16))\n    data_config = train_config.data.create(train_config.assets_dirs, train_config.model)\n    if norm_stats is None:\n        # We are loading the norm stats from the checkpoint instead of the config assets dir to make sure\n        # that the policy is using the same normalization stats as the original training process.\n        if data_config.asset_id is None:\n            raise ValueError(\"Asset id is required to load norm stats.\")\n        norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / \"assets\", data_config.asset_id)\n\n    # Determine the device to use for PyTorch models\n    if is_pytorch and pytorch_device is None:\n        try:\n            import torch\n\n            pytorch_device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        except ImportError:\n            pytorch_device = \"cpu\"\n\n    return _policy.Policy(\n        model,\n        transforms=[\n            *repack_transforms.inputs,\n            transforms.InjectDefaultPrompt(default_prompt),\n            *data_config.data_transforms.inputs,\n            transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),\n            *data_config.model_transforms.inputs,\n        ],\n        output_transforms=[\n            *data_config.model_transforms.outputs,\n            transforms.Unnormalize(norm_stats, use_quantiles=data_config.use_quantile_norm),\n            *data_config.data_transforms.outputs,\n            *repack_transforms.outputs,\n        ],\n        sample_kwargs=sample_kwargs,\n        metadata=train_config.policy_metadata,\n        is_pytorch=is_pytorch,\n        # `Policy.pytorch_device` is typed `str` and only used for PyTorch models, so pass its default\n        # (\"cpu\") rather than None for JAX models.\n        pytorch_device=pytorch_device if is_pytorch else \"cpu\",\n    )\n"
  },
  {
    "path": "src/openpi/policies/policy_test.py",
    "content": "from openpi_client import action_chunk_broker\nimport pytest\n\nfrom openpi.policies import aloha_policy\nfrom openpi.policies import policy_config as _policy_config\nfrom openpi.training import config as _config\n\n\ndef _load_aloha_sim_policy():\n    \"\"\"Returns the `pi0_aloha_sim` train config together with a policy restored from its checkpoint.\"\"\"\n    train_config = _config.get_config(\"pi0_aloha_sim\")\n    trained_policy = _policy_config.create_trained_policy(train_config, \"gs://openpi-assets/checkpoints/pi0_aloha_sim\")\n    return train_config, trained_policy\n\n\n@pytest.mark.manual\ndef test_infer():\n    config, policy = _load_aloha_sim_policy()\n\n    result = policy.infer(aloha_policy.make_aloha_example())\n\n    assert result[\"actions\"].shape == (config.model.action_horizon, 14)\n\n\n@pytest.mark.manual\ndef test_broker():\n    config, policy = _load_aloha_sim_policy()\n\n    # Only execute the first half of the chunk.\n    half_horizon = config.model.action_horizon // 2\n    broker = action_chunk_broker.ActionChunkBroker(policy, action_horizon=half_horizon)\n\n    example = aloha_policy.make_aloha_example()\n    for _ in range(config.model.action_horizon):\n        step_outputs = broker.infer(example)\n        assert step_outputs[\"actions\"].shape == (14,)\n"
  },
  {
    "path": "src/openpi/py.typed",
    "content": ""
  },
  {
    "path": "src/openpi/serving/websocket_policy_server.py",
    "content": "import asyncio\nimport http\nimport logging\nimport time\nimport traceback\n\nfrom openpi_client import base_policy as _base_policy\nfrom openpi_client import msgpack_numpy\nimport websockets.asyncio.server as _server\nimport websockets.frames\n\nlogger = logging.getLogger(__name__)\n\n\nclass WebsocketPolicyServer:\n    \"\"\"Exposes a policy over the websocket protocol; websocket_client_policy.py contains the matching client.\n\n    Only the `load` and `infer` methods are implemented so far.\n    \"\"\"\n\n    def __init__(\n        self,\n        policy: _base_policy.BasePolicy,\n        host: str = \"0.0.0.0\",\n        port: int | None = None,\n        metadata: dict | None = None,\n    ) -> None:\n        self._policy = policy\n        self._host = host\n        self._port = port\n        self._metadata = metadata if metadata else {}\n        logging.getLogger(\"websockets.server\").setLevel(logging.INFO)\n\n    def serve_forever(self) -> None:\n        \"\"\"Runs the server on a new asyncio event loop, blocking the calling thread.\"\"\"\n        asyncio.run(self.run())\n\n    async def run(self):\n        \"\"\"Serves connections until cancelled; requests to /healthz are answered by `_health_check`.\"\"\"\n        serve_options = {\n            \"compression\": None,\n            \"max_size\": None,\n            \"process_request\": _health_check,\n        }\n        async with _server.serve(self._handler, self._host, self._port, **serve_options) as server:\n            await server.serve_forever()\n\n    async def _handler(self, websocket: _server.ServerConnection):\n        \"\"\"Sends the metadata once, then answers each received observation with the policy's output.\"\"\"\n        logger.info(f\"Connection from {websocket.remote_address} opened\")\n        packer = msgpack_numpy.Packer()\n        await websocket.send(packer.pack(self._metadata))\n\n        last_round_trip = None\n        while True:\n            try:\n                round_trip_start = time.monotonic()\n                observation = msgpack_numpy.unpackb(await websocket.recv())\n\n                infer_start = time.monotonic()\n                response = self._policy.infer(observation)\n                infer_elapsed = time.monotonic() - infer_start\n\n                timing = {\"infer_ms\": infer_elapsed * 1000}\n                if last_round_trip is not None:\n                    # Only the previous round trip can be reported, because its total includes the send time.\n                    timing[\"prev_total_ms\"] = last_round_trip * 1000\n                response[\"server_timing\"] = timing\n\n                await websocket.send(packer.pack(response))\n                last_round_trip = time.monotonic() - round_trip_start\n            except websockets.ConnectionClosed:\n                logger.info(f\"Connection from {websocket.remote_address} closed\")\n                break\n            except Exception:\n                await websocket.send(traceback.format_exc())\n                await websocket.close(\n                    code=websockets.frames.CloseCode.INTERNAL_ERROR,\n                    reason=\"Internal server error. Traceback included in previous frame.\",\n                )\n                raise\n\n\ndef _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None:\n    \"\"\"Answers `/healthz` with 200 OK; returns None for any other path so the websocket handshake proceeds.\"\"\"\n    if request.path != \"/healthz\":\n        return None\n    return connection.respond(http.HTTPStatus.OK, \"OK\\n\")\n"
  },
  {
    "path": "src/openpi/shared/__init__.py",
    "content": ""
  },
  {
    "path": "src/openpi/shared/array_typing.py",
    "content": "import contextlib\nimport functools as ft\nimport inspect\nfrom typing import TypeAlias, TypeVar, cast\n\nimport beartype\nimport jax\nimport jax._src.tree_util as private_tree_util\nimport jax.core\nfrom jaxtyping import ArrayLike\nfrom jaxtyping import Bool  # noqa: F401\nfrom jaxtyping import DTypeLike  # noqa: F401\nfrom jaxtyping import Float\nfrom jaxtyping import Int  # noqa: F401\nfrom jaxtyping import Key  # noqa: F401\nfrom jaxtyping import Num  # noqa: F401\nfrom jaxtyping import PyTree\nfrom jaxtyping import Real  # noqa: F401\nfrom jaxtyping import UInt8  # noqa: F401\nfrom jaxtyping import config\nfrom jaxtyping import jaxtyped\nimport jaxtyping._decorator\nimport torch\n\n# Redefine Array to include both JAX arrays and PyTorch tensors.\nArray = jax.Array | torch.Tensor\n\n# patch jaxtyping to handle https://github.com/patrick-kidger/jaxtyping/issues/277.\n# the problem is that custom PyTree nodes are sometimes initialized with arbitrary types (e.g., `jax.ShapeDtypeStruct`,\n# `jax.Sharding`, or even <object>) due to JAX tracing operations. this patch skips typechecking when the stack trace\n# contains `jax._src.tree_util` (which should only be the case during tree unflattening) or\n# `flax.nnx.transforms.compilation`.\n_original_check_dataclass_annotations = jaxtyping._decorator._check_dataclass_annotations  # noqa: SLF001\n\n\ndef _check_dataclass_annotations(self, typechecker):\n    if not any(\n        frame.frame.f_globals.get(\"__name__\") in {\"jax._src.tree_util\", \"flax.nnx.transforms.compilation\"}\n        for frame in inspect.stack()\n    ):\n        return _original_check_dataclass_annotations(self, typechecker)\n    return None\n\n\njaxtyping._decorator._check_dataclass_annotations = _check_dataclass_annotations  # noqa: SLF001\n\nKeyArrayLike: TypeAlias = jax.typing.ArrayLike\nParams: TypeAlias = PyTree[Float[ArrayLike, \"...\"]]\n\nT = TypeVar(\"T\")\n\n\n# runtime type-checking decorator\ndef typecheck(t: T) -> T:\n    return cast(T, ft.partial(jaxtyped, typechecker=beartype.beartype)(t))\n\n\n@contextlib.contextmanager\ndef disable_typechecking():\n    \"\"\"Temporarily disables jaxtyping runtime type checking inside a `with` block.\n\n    The previous setting is restored on exit, including when the block raises.\n    \"\"\"\n    initial = config.jaxtyping_disable\n    config.update(\"jaxtyping_disable\", True)  # noqa: FBT003\n    try:\n        yield\n    finally:\n        config.update(\"jaxtyping_disable\", initial)\n\n\ndef check_pytree_equality(*, expected: PyTree, got: PyTree, check_shapes: bool = False, check_dtypes: bool = False):\n    \"\"\"Checks that two PyTrees have the same structure and optionally checks shapes and dtypes. Creates a much nicer\n    error message than if `jax.tree.map` is naively used on PyTrees with different structures.\n    \"\"\"\n\n    if errors := list(private_tree_util.equality_errors(expected, got)):\n        raise ValueError(\n            \"PyTrees have different structure:\\n\"\n            + (\n                \"\\n\".join(\n                    f\"   - at keypath '{jax.tree_util.keystr(path)}': expected {thing1}, got {thing2}, so {explanation}.\\n\"\n                    for path, thing1, thing2, explanation in errors\n                )\n            )\n        )\n\n    if check_shapes or check_dtypes:\n\n        def check(kp, x, y):\n            if check_shapes and x.shape != y.shape:\n                raise ValueError(f\"Shape mismatch at {jax.tree_util.keystr(kp)}: expected {x.shape}, got {y.shape}\")\n\n            if check_dtypes and x.dtype != y.dtype:\n                raise ValueError(f\"Dtype mismatch at {jax.tree_util.keystr(kp)}: expected {x.dtype}, got {y.dtype}\")\n\n        jax.tree_util.tree_map_with_path(check, expected, got)\n"
  },
  {
    "path": "src/openpi/shared/download.py",
    "content": "import concurrent.futures\nimport datetime\nimport logging\nimport os\nimport pathlib\nimport re\nimport shutil\nimport stat\nimport subprocess\nimport time\nimport urllib.parse\n\nimport filelock\nimport fsspec\nimport fsspec.generic\nimport tqdm_loggable.auto as tqdm\n\n# Environment variable to control cache directory path, ~/.cache/openpi will be used by default.\n_OPENPI_DATA_HOME = \"OPENPI_DATA_HOME\"\nDEFAULT_CACHE_DIR = \"~/.cache/openpi\"\n\nlogger = logging.getLogger(__name__)\n\n\ndef get_cache_dir() -> pathlib.Path:\n    cache_dir = pathlib.Path(os.getenv(_OPENPI_DATA_HOME, DEFAULT_CACHE_DIR)).expanduser().resolve()\n    cache_dir.mkdir(parents=True, exist_ok=True)\n    _set_folder_permission(cache_dir)\n    return cache_dir\n\n\ndef maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path:\n    \"\"\"Download a file or directory from a remote filesystem to the local cache, and return the local path.\n\n    If the local file already exists, it will be returned directly.\n\n    It is safe to call this function concurrently from multiple processes.\n    See `get_cache_dir` for more details on the cache directory.\n\n    Args:\n        url: URL to the file to download.\n        force_download: If True, the file will be downloaded even if it already exists in the cache.\n        **kwargs: Additional arguments to pass to fsspec.\n\n    Returns:\n        Local path to the downloaded file or directory. That path is guaranteed to exist and is absolute.\n    \"\"\"\n    # Don't use fsspec to parse the url to avoid unnecessary connection to the remote filesystem.\n    parsed = urllib.parse.urlparse(url)\n\n    # Short circuit if this is a local path.\n    if parsed.scheme == \"\":\n        path = pathlib.Path(url)\n        if not path.exists():\n            raise FileNotFoundError(f\"File not found at {url}\")\n        return path.resolve()\n\n    cache_dir = get_cache_dir()\n\n    local_path = cache_dir / parsed.netloc / parsed.path.strip(\"/\")\n    local_path = local_path.resolve()\n\n    # Check if the cache should be invalidated.\n    invalidate_cache = False\n    if local_path.exists():\n        if force_download or _should_invalidate_cache(cache_dir, local_path):\n            invalidate_cache = True\n        else:\n            return local_path\n\n    try:\n        lock_path = local_path.with_suffix(\".lock\")\n        with filelock.FileLock(lock_path):\n            # Ensure consistent permissions for the lock file.\n            _ensure_permissions(lock_path)\n            # First, remove the existing cache if it is expired.\n            if invalidate_cache:\n                logger.info(f\"Removing expired cached entry: {local_path}\")\n                if local_path.is_dir():\n                    shutil.rmtree(local_path)\n                else:\n                    local_path.unlink()\n\n            if not local_path.exists():\n                # Download the data to a local cache.\n                logger.info(f\"Downloading {url} to {local_path}\")\n                scratch_path = local_path.with_suffix(\".partial\")\n                # Clear leftover scratch data from a previously failed or interrupted download so it is not\n                # merged with (or mistaken for) the new download.\n                if scratch_path.is_dir():\n                    shutil.rmtree(scratch_path)\n                elif scratch_path.exists():\n                    scratch_path.unlink()\n                # Route openpi-assets through gsutil to avoid gcsfs auth issues with this bucket.\n                # All other gs:// URLs (e.g. big_vision) continue to use gcsfs as normal.\n                if parsed.scheme == \"gs\" and parsed.netloc == \"openpi-assets\":\n                    _download_gsutil(url, scratch_path, **kwargs)\n                else:\n                    _download_fsspec(url, scratch_path, **kwargs)\n\n                shutil.move(scratch_path, local_path)\n                _ensure_permissions(local_path)\n\n    except PermissionError as e:\n        msg = (\n            f\"Local file permission error was encountered while downloading {url}. \"\n            f\"Please try again after removing the cached data using: `rm -rf {local_path}*`\"\n        )\n        raise PermissionError(msg) from e\n\n    return local_path\n\n\ndef _download_gsutil(url: str, local_path: pathlib.Path, **kwargs) -> None:\n    \"\"\"Download a file or directory from GCS using gsutil if available, otherwise fall back to gcsfs.\"\"\"\n    if shutil.which(\"gsutil\") is None:\n        logger.warning(\n            \"gsutil not found, falling back to gcsfs. This may fail if GCP credentials are not configured correctly.\"\n        )\n        _download_fsspec(url, local_path, **kwargs)\n        return\n    # `gsutil -q stat` exits with status 0 only if `url` names a single object (i.e. a file). For such URLs the\n    # `{url}/*` wildcard used for directories below would match nothing, so copy the object directly instead.\n    if subprocess.run([\"gsutil\", \"-q\", \"stat\", url], check=False).returncode == 0:\n        local_path.parent.mkdir(parents=True, exist_ok=True)\n        subprocess.run([\"gsutil\", \"cp\", url, str(local_path)], check=True)\n        return\n    local_path.mkdir(parents=True, exist_ok=True)\n    subprocess.run(\n        [\"gsutil\", \"-m\", \"cp\", \"-r\", f\"{url}/*\", str(local_path)],\n        check=True,\n    )\n\n\ndef _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None:\n    \"\"\"Download a file or directory from a remote filesystem to `local_path`, showing a progress bar.\"\"\"\n    fs, _ = fsspec.core.url_to_fs(url, **kwargs)\n    info = fs.info(url)\n    # Folders are represented by 0-byte objects with a trailing forward slash.\n    if is_dir := (info[\"type\"] == \"directory\" or (info[\"size\"] == 0 and info[\"name\"].endswith(\"/\"))):\n        total_size = fs.du(url)\n    else:\n        total_size = info[\"size\"]\n    with tqdm.tqdm(total=total_size, unit=\"iB\", unit_scale=True, unit_divisor=1024) as pbar:\n        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:\n            future = executor.submit(fs.get, url, local_path, recursive=is_dir)\n            while not future.done():\n                current_size = sum(f.stat().st_size for f in [*local_path.rglob(\"*\"), local_path] if f.is_file())\n                pbar.update(current_size - pbar.n)\n                time.sleep(1)\n            # Re-raise any error from the download thread. Without this, a failed download would go unnoticed\n            # and the caller would move the incomplete scratch data into the cache.\n            future.result()\n        pbar.update(total_size - pbar.n)\n\n\ndef _set_permission(path: pathlib.Path, target_permission: int):\n    \"\"\"Chmod `path` to `target_permission`, skipping the chmod when all of the target bits are already set.\"\"\"\n    if path.stat().st_mode & target_permission == target_permission:\n        logger.debug(f\"Skipping {path} because it already has correct permissions\")\n        return\n    path.chmod(target_permission)\n    logger.debug(f\"Set {path} to {target_permission}\")\n\n\ndef _set_folder_permission(folder_path: pathlib.Path) -> None:\n    \"\"\"Set folder permission to be read, write and searchable.\"\"\"\n    _set_permission(folder_path, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)\n\n\ndef _ensure_permissions(path: pathlib.Path) -> None:\n    \"\"\"Since we are sharing cache directory with containerized runtime as well as training script, we need to\n    ensure that the cache directory has the correct permissions.\n    \"\"\"\n\n    def _setup_folder_permission_between_cache_dir_and_path(path: pathlib.Path) -> None:\n        cache_dir = get_cache_dir()\n        relative_path = path.relative_to(cache_dir)\n        moving_path = cache_dir\n        for part in relative_path.parts:\n            moving_path = moving_path / part\n            # Only directories get folder (rwx) permissions; files are handled by `_set_file_permission`.\n            if moving_path.is_dir():\n                _set_folder_permission(moving_path)\n\n    def _set_file_permission(file_path: pathlib.Path) -> None:\n        \"\"\"Set all files to be read & writable, if it is a script, keep it as a script.\"\"\"\n        file_rw = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH | stat.S_IWOTH\n        if file_path.stat().st_mode & 0o100:\n            _set_permission(file_path, file_rw | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)\n        else:\n            _set_permission(file_path, file_rw)\n\n    _setup_folder_permission_between_cache_dir_and_path(path)\n    if path.is_file():\n        # `os.walk` yields nothing for a file, so a single downloaded file (or lock file) is handled here.\n        _set_file_permission(path)\n    for root, dirs, files in os.walk(str(path)):\n        root_path = pathlib.Path(root)\n        for file in files:\n            file_path = root_path / file\n            _set_file_permission(file_path)\n\n        for dir in dirs:\n            dir_path = root_path / dir\n            _set_folder_permission(dir_path)\n\n\ndef _get_mtime(year: int, month: int, day: int) -> float:\n    \"\"\"Get the mtime of a given date at midnight UTC.\"\"\"\n    # Use the aware datetime's own timestamp: `time.mktime` would reinterpret the UTC time tuple as local time.\n    return datetime.datetime(year, month, day, tzinfo=datetime.UTC).timestamp()\n\n\n# Map of relative paths, defined as regular expressions, to expiration timestamps (mtime format).\n# Prefix matching (`re.match`) will be used from top to bottom and the first match will be chosen.\n# Cached entries will be retained only if they are newer than the expiration timestamp.\n_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = {\n    re.compile(\"openpi-assets/checkpoints/pi0_aloha_pen_uncap\"): _get_mtime(2025, 2, 17),\n    re.compile(\"openpi-assets/checkpoints/pi0_libero\"): _get_mtime(2025, 2, 6),\n    re.compile(\"openpi-assets/checkpoints/\"): _get_mtime(2025, 2, 3),\n}\n\n\ndef _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:\n    \"\"\"Return True if the cached entry at `local_path` is expired and should be downloaded again.\"\"\"\n\n    assert local_path.exists(), f\"File not found at {local_path}\"\n\n    relative_path = str(local_path.relative_to(cache_dir))\n    for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items():\n        if pattern.match(relative_path):\n            # Remove if not newer than the expiration timestamp.\n            return local_path.stat().st_mtime <= expire_time\n\n    return False\n"
  },
  {
    "path": "src/openpi/shared/download_test.py",
    "content": "import pathlib\n\nimport pytest\n\nimport openpi.shared.download as download\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef set_openpi_data_home(tmp_path_factory):\n    temp_dir = tmp_path_factory.mktemp(\"openpi_data\")\n    with pytest.MonkeyPatch().context() as mp:\n        mp.setenv(\"OPENPI_DATA_HOME\", str(temp_dir))\n        yield\n\n\ndef test_download_local(tmp_path: pathlib.Path):\n    local_path = tmp_path / \"local\"\n    local_path.touch()\n\n    result = download.maybe_download(str(local_path))\n    assert result == local_path\n\n    with pytest.raises(FileNotFoundError):\n        download.maybe_download(\"bogus\")\n\n\ndef test_download_gs_dir():\n    remote_path = \"gs://openpi-assets/testdata/random\"\n\n    local_path = download.maybe_download(remote_path)\n    assert local_path.exists()\n\n    new_local_path = download.maybe_download(remote_path)\n    assert new_local_path == local_path\n\n\ndef test_download_gs():\n    remote_path = \"gs://openpi-assets/testdata/random/random_512kb.bin\"\n\n    local_path = download.maybe_download(remote_path)\n    assert local_path.exists()\n\n    new_local_path = download.maybe_download(remote_path)\n    assert new_local_path == local_path\n\n\ndef test_download_fsspec():\n    remote_path = \"gs://big_vision/paligemma_tokenizer.model\"\n\n    local_path = download.maybe_download(remote_path, gs={\"token\": \"anon\"})\n    assert local_path.exists()\n\n    new_local_path = download.maybe_download(remote_path, gs={\"token\": \"anon\"})\n    assert new_local_path == local_path\n"
  },
  {
    "path": "src/openpi/shared/image_tools.py",
    "content": "import functools\n\nimport jax\nimport jax.numpy as jnp\nimport torch\nimport torch.nn.functional as F  # noqa: N812\n\nimport openpi.shared.array_typing as at\n\n\n@functools.partial(jax.jit, static_argnums=(1, 2, 3))\n@at.typecheck\ndef resize_with_pad(\n    images: at.UInt8[at.Array, \"*b h w c\"] | at.Float[at.Array, \"*b h w c\"],\n    height: int,\n    width: int,\n    method: jax.image.ResizeMethod = jax.image.ResizeMethod.LINEAR,\n) -> at.UInt8[at.Array, \"*b {height} {width} c\"] | at.Float[at.Array, \"*b {height} {width} c\"]:\n    \"\"\"Replicates tf.image.resize_with_pad. Resizes an image to a target height and width without distortion\n    by padding with black. If the image is float32, it must be in the range [-1, 1].\n    \"\"\"\n    has_batch_dim = images.ndim == 4\n    if not has_batch_dim:\n        images = images[None]  # type: ignore\n    cur_height, cur_width = images.shape[1:3]\n    ratio = max(cur_width / width, cur_height / height)\n    resized_height = int(cur_height / ratio)\n    resized_width = int(cur_width / ratio)\n    resized_images = jax.image.resize(\n        images, (images.shape[0], resized_height, resized_width, images.shape[3]), method=method\n    )\n    if images.dtype == jnp.uint8:\n        # round from float back to uint8\n        resized_images = jnp.round(resized_images).clip(0, 255).astype(jnp.uint8)\n    elif images.dtype == jnp.float32:\n        resized_images = resized_images.clip(-1.0, 1.0)\n    else:\n        raise ValueError(f\"Unsupported image dtype: {images.dtype}\")\n\n    pad_h0, remainder_h = divmod(height - resized_height, 2)\n    pad_h1 = pad_h0 + remainder_h\n    pad_w0, remainder_w = divmod(width - resized_width, 2)\n    pad_w1 = pad_w0 + remainder_w\n    padded_images = jnp.pad(\n        resized_images,\n        ((0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1), (0, 0)),\n        constant_values=0 if images.dtype == jnp.uint8 else -1.0,\n    )\n\n    if not has_batch_dim:\n        padded_images = 
padded_images[0]\n    return padded_images\n\n\ndef resize_with_pad_torch(\n    images: torch.Tensor,\n    height: int,\n    width: int,\n    mode: str = \"bilinear\",\n) -> torch.Tensor:\n    \"\"\"PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion\n    by padding with black. If the image is float32, it must be in the range [-1, 1].\n\n    Args:\n        images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]\n        height: Target height\n        width: Target width\n        mode: Interpolation mode ('bilinear', 'nearest', etc.)\n\n    Returns:\n        Resized and padded tensor with same shape format as input\n    \"\"\"\n    # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]\n    if images.shape[-1] <= 4:  # Assume channels-last format\n        channels_last = True\n        # Convert to channels-first for torch operations\n        if images.dim() == 3:\n            images = images.unsqueeze(0)  # Add batch dimension\n        images = images.permute(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]\n    else:\n        channels_last = False\n        if images.dim() == 3:\n            images = images.unsqueeze(0)  # Add batch dimension\n\n    batch_size, channels, cur_height, cur_width = images.shape\n\n    # Calculate resize ratio\n    ratio = max(cur_width / width, cur_height / height)\n    resized_height = int(cur_height / ratio)\n    resized_width = int(cur_width / ratio)\n\n    # Resize\n    resized_images = F.interpolate(\n        images, size=(resized_height, resized_width), mode=mode, align_corners=False if mode == \"bilinear\" else None\n    )\n\n    # Handle dtype-specific clipping\n    if images.dtype == torch.uint8:\n        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)\n    elif images.dtype == torch.float32:\n        resized_images = resized_images.clamp(-1.0, 1.0)\n    else:\n        raise ValueError(f\"Unsupported image dtype: 
{images.dtype}\")\n\n    # Calculate padding\n    pad_h0, remainder_h = divmod(height - resized_height, 2)\n    pad_h1 = pad_h0 + remainder_h\n    pad_w0, remainder_w = divmod(width - resized_width, 2)\n    pad_w1 = pad_w0 + remainder_w\n\n    # Pad\n    constant_value = 0 if images.dtype == torch.uint8 else -1.0\n    padded_images = F.pad(\n        resized_images,\n        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom\n        mode=\"constant\",\n        value=constant_value,\n    )\n\n    # Convert back to original format if needed\n    if channels_last:\n        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]\n        if batch_size == 1 and images.shape[0] == 1:\n            padded_images = padded_images.squeeze(0)  # Remove batch dimension if it was added\n\n    return padded_images\n"
  },
  {
    "path": "src/openpi/shared/image_tools_test.py",
    "content": "import jax.numpy as jnp\n\nfrom openpi.shared import image_tools\n\n\ndef test_resize_with_pad_shapes():\n    # Test case 1: Resize image with larger dimensions\n    images = jnp.zeros((2, 10, 10, 3), dtype=jnp.uint8)  # Input images of shape (batch_size, height, width, channels)\n    height = 20\n    width = 20\n    resized_images = image_tools.resize_with_pad(images, height, width)\n    assert resized_images.shape == (2, height, width, 3)\n    assert jnp.all(resized_images == 0)\n\n    # Test case 2: Resize image with smaller dimensions\n    images = jnp.zeros((3, 30, 30, 3), dtype=jnp.uint8)\n    height = 15\n    width = 15\n    resized_images = image_tools.resize_with_pad(images, height, width)\n    assert resized_images.shape == (3, height, width, 3)\n    assert jnp.all(resized_images == 0)\n\n    # Test case 3: Resize image with the same dimensions\n    images = jnp.zeros((1, 50, 50, 3), dtype=jnp.uint8)\n    height = 50\n    width = 50\n    resized_images = image_tools.resize_with_pad(images, height, width)\n    assert resized_images.shape == (1, height, width, 3)\n    assert jnp.all(resized_images == 0)\n\n    # Test case 3: Resize image with odd-numbered padding\n    images = jnp.zeros((1, 256, 320, 3), dtype=jnp.uint8)\n    height = 60\n    width = 80\n    resized_images = image_tools.resize_with_pad(images, height, width)\n    assert resized_images.shape == (1, height, width, 3)\n    assert jnp.all(resized_images == 0)\n"
  },
  {
    "path": "src/openpi/shared/nnx_utils.py",
    "content": "from collections.abc import Callable\nimport dataclasses\nimport functools\nimport inspect\nimport re\nfrom typing import Any, ParamSpec, TypeVar\n\nimport flax.nnx as nnx\nimport jax\n\nP = ParamSpec(\"P\")\nR = TypeVar(\"R\")\n\n\ndef module_jit(meth: Callable[P, R], *jit_args, **jit_kwargs) -> Callable[P, R]:\n    \"\"\"A higher-order function to JIT-compile `nnx.Module` methods, freezing the module's state in the process.\n\n    Why not `nnx.jit`? For some reason, naively applying `nnx.jit` to `nnx.Module` methods, bound or unbound, uses much\n    more memory than necessary. I'm guessing it has something to do with the fact that it must keep track of module\n    mutations. Also, `nnx.jit` has some inherent overhead compared to a standard `jax.jit`, since every call must\n    traverse the NNX module graph. See https://github.com/google/flax/discussions/4224 for details.\n\n    `module_jit` is an alternative that avoids these issues by freezing the module's state. The function returned by\n    `module_jit` acts exactly like the original method, except that the state of the module is frozen to whatever it was\n    when `module_jit` was called. 
Mutations to the module within `meth` are still allowed, but they will be discarded\n    after the method call completes.\n    \"\"\"\n    if not (inspect.ismethod(meth) and isinstance(meth.__self__, nnx.Module)):\n        raise ValueError(\"module_jit must only be used on bound methods of nnx.Modules.\")\n\n    graphdef, state = nnx.split(meth.__self__)\n\n    def fun(state: nnx.State, *args: P.args, **kwargs: P.kwargs) -> R:\n        module = nnx.merge(graphdef, state)\n        return meth.__func__(module, *args, **kwargs)\n\n    jitted_fn = jax.jit(fun, *jit_args, **jit_kwargs)\n\n    @functools.wraps(meth)\n    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:\n        return jitted_fn(state, *args, **kwargs)\n\n    return wrapper\n\n\n@dataclasses.dataclass(frozen=True)\nclass PathRegex:\n    \"\"\"NNX Filter that matches paths using a regex.\n\n    By default, paths are joined with a `/` separator. This can be overridden by setting the `sep` argument.\n    \"\"\"\n\n    pattern: str | re.Pattern\n    sep: str = \"/\"\n\n    def __post_init__(self):\n        if not isinstance(self.pattern, re.Pattern):\n            object.__setattr__(self, \"pattern\", re.compile(self.pattern))\n\n    def __call__(self, path: nnx.filterlib.PathParts, x: Any) -> bool:\n        joined_path = self.sep.join(str(x) for x in path)\n        assert isinstance(self.pattern, re.Pattern)\n        return self.pattern.fullmatch(joined_path) is not None\n\n\ndef state_map(state: nnx.State, filter: nnx.filterlib.Filter, fn: Callable[[Any], Any]) -> nnx.State:\n    \"\"\"Apply a function to the leaves of the state that match the filter.\"\"\"\n    filtered_keys = set(state.filter(filter).flat_state())\n    return state.map(lambda k, v: fn(v) if k in filtered_keys else v)\n"
  },
  {
    "path": "src/openpi/shared/normalize.py",
    "content": "import json\nimport pathlib\n\nimport numpy as np\nimport numpydantic\nimport pydantic\n\n\n@pydantic.dataclasses.dataclass\nclass NormStats:\n    mean: numpydantic.NDArray\n    std: numpydantic.NDArray\n    q01: numpydantic.NDArray | None = None  # 1st quantile\n    q99: numpydantic.NDArray | None = None  # 99th quantile\n\n\nclass RunningStats:\n    \"\"\"Compute running statistics of a batch of vectors.\"\"\"\n\n    def __init__(self):\n        self._count = 0\n        self._mean = None\n        self._mean_of_squares = None\n        self._min = None\n        self._max = None\n        self._histograms = None\n        self._bin_edges = None\n        self._num_quantile_bins = 5000  # for computing quantiles on the fly\n\n    def update(self, batch: np.ndarray) -> None:\n        \"\"\"\n        Update the running statistics with a batch of vectors.\n\n        Args:\n            vectors (np.ndarray): An array where all dimensions except the last are batch dimensions.\n        \"\"\"\n        batch = batch.reshape(-1, batch.shape[-1])\n        num_elements, vector_length = batch.shape\n        if self._count == 0:\n            self._mean = np.mean(batch, axis=0)\n            self._mean_of_squares = np.mean(batch**2, axis=0)\n            self._min = np.min(batch, axis=0)\n            self._max = np.max(batch, axis=0)\n            self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]\n            self._bin_edges = [\n                np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)\n                for i in range(vector_length)\n            ]\n        else:\n            if vector_length != self._mean.size:\n                raise ValueError(\"The length of new vectors does not match the initialized vector length.\")\n            new_max = np.max(batch, axis=0)\n            new_min = np.min(batch, axis=0)\n            max_changed = np.any(new_max > self._max)\n            min_changed = 
np.any(new_min < self._min)\n            self._max = np.maximum(self._max, new_max)\n            self._min = np.minimum(self._min, new_min)\n\n            if max_changed or min_changed:\n                self._adjust_histograms()\n\n        self._count += num_elements\n\n        batch_mean = np.mean(batch, axis=0)\n        batch_mean_of_squares = np.mean(batch**2, axis=0)\n\n        # Update running mean and mean of squares.\n        self._mean += (batch_mean - self._mean) * (num_elements / self._count)\n        self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (num_elements / self._count)\n\n        self._update_histograms(batch)\n\n    def get_statistics(self) -> NormStats:\n        \"\"\"\n        Compute and return the statistics of the vectors processed so far.\n\n        Returns:\n            dict: A dictionary containing the computed statistics.\n        \"\"\"\n        if self._count < 2:\n            raise ValueError(\"Cannot compute statistics for less than 2 vectors.\")\n\n        variance = self._mean_of_squares - self._mean**2\n        stddev = np.sqrt(np.maximum(0, variance))\n        q01, q99 = self._compute_quantiles([0.01, 0.99])\n        return NormStats(mean=self._mean, std=stddev, q01=q01, q99=q99)\n\n    def _adjust_histograms(self):\n        \"\"\"Adjust histograms when min or max changes.\"\"\"\n        for i in range(len(self._histograms)):\n            old_edges = self._bin_edges[i]\n            new_edges = np.linspace(self._min[i], self._max[i], self._num_quantile_bins + 1)\n\n            # Redistribute the existing histogram counts to the new bins\n            new_hist, _ = np.histogram(old_edges[:-1], bins=new_edges, weights=self._histograms[i])\n\n            self._histograms[i] = new_hist\n            self._bin_edges[i] = new_edges\n\n    def _update_histograms(self, batch: np.ndarray) -> None:\n        \"\"\"Update histograms with new vectors.\"\"\"\n        for i in range(batch.shape[1]):\n            hist, _ 
= np.histogram(batch[:, i], bins=self._bin_edges[i])\n            self._histograms[i] += hist\n\n    def _compute_quantiles(self, quantiles):\n        \"\"\"Compute quantiles based on histograms.\"\"\"\n        results = []\n        for q in quantiles:\n            target_count = q * self._count\n            q_values = []\n            for hist, edges in zip(self._histograms, self._bin_edges, strict=True):\n                cumsum = np.cumsum(hist)\n                idx = np.searchsorted(cumsum, target_count)\n                q_values.append(edges[idx])\n            results.append(np.array(q_values))\n        return results\n\n\nclass _NormStatsDict(pydantic.BaseModel):\n    norm_stats: dict[str, NormStats]\n\n\ndef serialize_json(norm_stats: dict[str, NormStats]) -> str:\n    \"\"\"Serialize the running statistics to a JSON string.\"\"\"\n    return _NormStatsDict(norm_stats=norm_stats).model_dump_json(indent=2)\n\n\ndef deserialize_json(data: str) -> dict[str, NormStats]:\n    \"\"\"Deserialize the running statistics from a JSON string.\"\"\"\n    return _NormStatsDict(**json.loads(data)).norm_stats\n\n\ndef save(directory: pathlib.Path | str, norm_stats: dict[str, NormStats]) -> None:\n    \"\"\"Save the normalization stats to a directory.\"\"\"\n    path = pathlib.Path(directory) / \"norm_stats.json\"\n    path.parent.mkdir(parents=True, exist_ok=True)\n    path.write_text(serialize_json(norm_stats))\n\n\ndef load(directory: pathlib.Path | str) -> dict[str, NormStats]:\n    \"\"\"Load the normalization stats from a directory.\"\"\"\n    path = pathlib.Path(directory) / \"norm_stats.json\"\n    if not path.exists():\n        raise FileNotFoundError(f\"Norm stats file not found at: {path}\")\n    return deserialize_json(path.read_text())\n"
  },
  {
    "path": "src/openpi/shared/normalize_test.py",
    "content": "import numpy as np\n\nimport openpi.shared.normalize as normalize\n\n\ndef test_normalize_update():\n    arr = np.arange(12).reshape(4, 3)  # 4 vectors of length 3\n\n    stats = normalize.RunningStats()\n    for i in range(len(arr)):\n        stats.update(arr[i : i + 1])  # Update with one vector at a time\n    results = stats.get_statistics()\n\n    assert np.allclose(results.mean, np.mean(arr, axis=0))\n    assert np.allclose(results.std, np.std(arr, axis=0))\n\n\ndef test_serialize_deserialize():\n    stats = normalize.RunningStats()\n    stats.update(np.arange(12).reshape(4, 3))  # 4 vectors of length 3\n\n    norm_stats = {\"test\": stats.get_statistics()}\n    norm_stats2 = normalize.deserialize_json(normalize.serialize_json(norm_stats))\n    assert np.allclose(norm_stats[\"test\"].mean, norm_stats2[\"test\"].mean)\n    assert np.allclose(norm_stats[\"test\"].std, norm_stats2[\"test\"].std)\n\n\ndef test_multiple_batch_dimensions():\n    # Test with multiple batch dimensions: (2, 3, 4) where 4 is vector dimension\n    batch_shape = (2, 3, 4)\n    arr = np.random.rand(*batch_shape)\n\n    stats = normalize.RunningStats()\n    stats.update(arr)  # Should handle (2, 3, 4) -> reshape to (6, 4)\n    results = stats.get_statistics()\n\n    # Flatten batch dimensions and compute expected stats\n    flattened = arr.reshape(-1, arr.shape[-1])  # (6, 4)\n    expected_mean = np.mean(flattened, axis=0)\n    expected_std = np.std(flattened, axis=0)\n\n    assert np.allclose(results.mean, expected_mean)\n    assert np.allclose(results.std, expected_std)\n"
  },
  {
    "path": "src/openpi/training/checkpoints.py",
    "content": "from __future__ import annotations\n\nimport asyncio\nimport concurrent.futures as futures\nimport dataclasses\nimport logging\nfrom typing import Protocol\n\nfrom etils import epath\nimport jax\nimport orbax.checkpoint as ocp\nimport orbax.checkpoint.future as future\n\nfrom openpi.shared import array_typing as at\nimport openpi.shared.normalize as _normalize\nimport openpi.training.data_loader as _data_loader\nimport openpi.training.utils as training_utils\n\n\ndef initialize_checkpoint_dir(\n    checkpoint_dir: epath.Path | str, *, keep_period: int | None, overwrite: bool, resume: bool\n) -> tuple[ocp.CheckpointManager, bool]:\n    checkpoint_dir = epath.Path(checkpoint_dir).resolve()\n    resuming = False\n    if checkpoint_dir.exists():\n        if overwrite:\n            checkpoint_dir.rmtree()\n            checkpoint_dir.mkdir(parents=True, exist_ok=True)\n            logging.info(f\"Wiped checkpoint directory {checkpoint_dir}\")\n        elif resume:\n            resuming = True\n        else:\n            raise FileExistsError(\n                f\"Checkpoint directory {checkpoint_dir} already exists. Use --overwrite or --resume \"\n                \"to indicate how to handle it.\"\n            )\n\n    checkpoint_dir.mkdir(parents=True, exist_ok=True)\n\n    mngr = ocp.CheckpointManager(\n        checkpoint_dir,\n        item_handlers={\n            \"assets\": CallbackHandler(),\n            \"train_state\": ocp.PyTreeCheckpointHandler(),\n            \"params\": ocp.PyTreeCheckpointHandler(),\n        },\n        options=ocp.CheckpointManagerOptions(\n            max_to_keep=1,\n            keep_period=keep_period,\n            create=False,\n            async_options=ocp.AsyncOptions(timeout_secs=7200),\n        ),\n    )\n\n    # Special case: the checkpoint directory exists and the user requests to resume training, but the training run did\n    # not get to the first checkpoint saved. 
In this case, we don't actually want the train script to try and restore a\n    # checkpoint, since it will fail.\n    if resuming and tuple(mngr.all_steps()) in [(), (0,)]:\n        logging.info(\"Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.\")\n        resuming = False\n\n    return mngr, resuming\n\n\ndef save_state(\n    checkpoint_manager: ocp.CheckpointManager,\n    state: training_utils.TrainState,\n    data_loader: _data_loader.DataLoader,\n    step: int,\n):\n    def save_assets(directory: epath.Path):\n        # Save the normalization stats.\n        data_config = data_loader.data_config()\n        norm_stats = data_config.norm_stats\n        if norm_stats is not None and data_config.asset_id is not None:\n            _normalize.save(directory / data_config.asset_id, norm_stats)\n\n    # Split params that can be used for inference into a separate item.\n    with at.disable_typechecking():\n        train_state, params = _split_params(state)\n    items = {\n        \"assets\": save_assets,\n        \"train_state\": train_state,\n        \"params\": {\"params\": params},\n    }\n    checkpoint_manager.save(step, items)\n\n\ndef restore_state(\n    checkpoint_manager: ocp.CheckpointManager,\n    state: training_utils.TrainState,\n    data_loader: _data_loader.DataLoader,\n    step: int | None = None,\n) -> training_utils.TrainState:\n    del data_loader\n\n    with at.disable_typechecking():\n        # Split params that can be used for inference into a separate item.\n        train_state, params = _split_params(state)\n        restored = checkpoint_manager.restore(\n            step,\n            items={\n                \"train_state\": train_state,\n                \"params\": {\"params\": params},\n            },\n        )\n    return _merge_params(restored[\"train_state\"], restored[\"params\"])\n\n\ndef load_norm_stats(assets_dir: epath.Path | str, asset_id: str) -> dict[str, _normalize.NormStats] | None:\n    
norm_stats_dir = epath.Path(assets_dir) / asset_id\n    norm_stats = _normalize.load(norm_stats_dir)\n    logging.info(f\"Loaded norm stats from {norm_stats_dir}\")\n    return norm_stats\n\n\nclass Callback(Protocol):\n    def __call__(self, directory: epath.Path) -> None: ...\n\n\nclass CallbackHandler(ocp.AsyncCheckpointHandler):\n    \"\"\"A CheckpointHandler for calling an arbitrary function asynchronously. Only for saving, not for restoring.\"\"\"\n\n    def save(self, directory: epath.Path, args: CallbackSave):\n        if jax.process_index() == 0:\n            args.callback(directory)\n\n    async def async_save(self, directory: epath.Path, args: CallbackSave) -> list[futures.Future]:\n        return [future.CommitFutureAwaitingContractedSignals(asyncio.to_thread(self.save, directory, args))]\n\n    def restore(self, *args, **kwargs):\n        raise NotImplementedError(\"CallbackHandler does not support restore\")\n\n\n@ocp.args.register_with_handler(CallbackHandler, for_save=True)\n@dataclasses.dataclass\nclass CallbackSave(ocp.args.CheckpointArgs):\n    callback: Callback\n\n\n@ocp.args.register_with_handler(CallbackHandler, for_restore=True)\nclass CallbackRestore(ocp.args.CheckpointArgs): ...\n\n\ndef _split_params(state: training_utils.TrainState) -> tuple[training_utils.TrainState, at.Params]:\n    if state.ema_params is not None:\n        params = state.ema_params\n        train_state = dataclasses.replace(state, ema_params=None)\n    else:\n        params = state.params\n        train_state = dataclasses.replace(state, params={})\n    return train_state, params\n\n\ndef _merge_params(train_state: training_utils.TrainState, params: dict[str, at.Params]) -> training_utils.TrainState:\n    # Revert the logic inside `_split_params`. 
Assumes that existence of `params` means that EMA params were used during the split.\n    if train_state.params:\n        return dataclasses.replace(train_state, ema_params=params[\"params\"])\n    return dataclasses.replace(train_state, params=params[\"params\"])\n"
  },
  {
    "path": "src/openpi/training/config.py",
    "content": "\"\"\"See _CONFIGS for the list of available configs.\"\"\"\n\nimport abc\nfrom collections.abc import Sequence\nimport dataclasses\nimport difflib\nimport logging\nimport pathlib\nfrom typing import Any, Literal, Protocol, TypeAlias\n\nimport etils.epath as epath\nimport flax.nnx as nnx\nfrom typing_extensions import override\nimport tyro\n\nimport openpi.models.model as _model\nimport openpi.models.pi0_config as pi0_config\nimport openpi.models.pi0_fast as pi0_fast\nimport openpi.models.tokenizer as _tokenizer\nimport openpi.policies.aloha_policy as aloha_policy\nimport openpi.policies.droid_policy as droid_policy\nimport openpi.policies.libero_policy as libero_policy\nimport openpi.shared.download as _download\nimport openpi.shared.normalize as _normalize\nimport openpi.training.droid_rlds_dataset as droid_rlds_dataset\nimport openpi.training.misc.polaris_config as polaris_config\nimport openpi.training.misc.roboarena_config as roboarena_config\nimport openpi.training.optimizer as _optimizer\nimport openpi.training.weight_loaders as weight_loaders\nimport openpi.transforms as _transforms\n\nModelType: TypeAlias = _model.ModelType\n# Work around a tyro issue with using nnx.filterlib.Filter directly.\nFilter: TypeAlias = nnx.filterlib.Filter\n\n\n@dataclasses.dataclass(frozen=True)\nclass AssetsConfig:\n    \"\"\"Determines the location of assets (e.g., norm stats) that will be used to set up the data pipeline.\n\n    These assets will be replicated inside the checkpoint under the `assets/asset_id` directory.\n\n    This can be used to load assets from a different checkpoint (e.g., base model checkpoint) or some other\n    centralized location. For example, to load the norm stats for the Trossen robot from the base model checkpoint\n    during fine-tuning, use:\n\n    ```\n    AssetsConfig(\n        assets_dir=\"gs://openpi-assets/checkpoints/pi0_base/assets\",\n        asset_id=\"trossen\",\n    )\n    ```\n    \"\"\"\n\n    # Assets directory. 
If not provided, the config assets_dirs will be used. This is useful to load assets from\n    # a different checkpoint (e.g., base model checkpoint) or some other centralized location.\n    assets_dir: str | None = None\n\n    # Asset id. If not provided, the repo id will be used. This allows users to reference assets that describe\n    # different robot platforms.\n    asset_id: str | None = None\n\n\n@dataclasses.dataclass(frozen=True)\nclass DataConfig:\n    # LeRobot repo id. If None, fake data will be created.\n    repo_id: str | None = None\n    # Directory within the assets directory containing the data assets.\n    asset_id: str | None = None\n    # Contains precomputed normalization stats. If None, normalization will not be performed.\n    norm_stats: dict[str, _transforms.NormStats] | None = None\n\n    # Used to adopt the inputs from a dataset specific format to a common format\n    # which is expected by the data transforms.\n    repack_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)\n    # Data transforms, typically include robot specific transformations. Will be applied\n    # before the data is normalized. See `model.Observation` and `model.Actions` to learn about the\n    # normalized data.\n    data_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)\n    # Model specific transforms. Will be applied after the data is normalized.\n    model_transforms: _transforms.Group = dataclasses.field(default_factory=_transforms.Group)\n    # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used.\n    use_quantile_norm: bool = False\n\n    # Names of keys that will be used by the data loader to generate the action sequence. The length of the\n    # sequence is defined by the `action_horizon` field in the model config. 
This should be adjusted if your\n    # LeRobot dataset is using different keys to represent the action.\n    action_sequence_keys: Sequence[str] = (\"actions\",)\n\n    # If true, will use the LeRobot dataset task to define the prompt.\n    prompt_from_task: bool = False\n\n    # Only used for RLDS data loader (ie currently only used for DROID).\n    rlds_data_dir: str | None = None\n    # Action space for DROID dataset.\n    action_space: droid_rlds_dataset.DroidActionSpace | None = None\n    # List of datasets to sample from: name, version, weight, and optionally filter_dict_path\n    datasets: Sequence[droid_rlds_dataset.RLDSDataset] = ()\n\n\nclass GroupFactory(Protocol):\n    def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:\n        \"\"\"Create a group.\"\"\"\n\n\n@dataclasses.dataclass(frozen=True)\nclass ModelTransformFactory(GroupFactory):\n    \"\"\"Creates model transforms for standard pi0 models.\"\"\"\n\n    # If provided, will determine the default prompt that be used by the model.\n    default_prompt: str | None = None\n\n    def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:\n        match model_config.model_type:\n            case _model.ModelType.PI0:\n                return _transforms.Group(\n                    inputs=[\n                        _transforms.InjectDefaultPrompt(self.default_prompt),\n                        _transforms.ResizeImages(224, 224),\n                        _transforms.TokenizePrompt(\n                            _tokenizer.PaligemmaTokenizer(model_config.max_token_len),\n                        ),\n                        _transforms.PadStatesAndActions(model_config.action_dim),\n                    ],\n                )\n            case _model.ModelType.PI05:\n                assert isinstance(model_config, pi0_config.Pi0Config)\n                return _transforms.Group(\n                    inputs=[\n                        
_transforms.InjectDefaultPrompt(self.default_prompt),\n                        _transforms.ResizeImages(224, 224),\n                        _transforms.TokenizePrompt(\n                            _tokenizer.PaligemmaTokenizer(model_config.max_token_len),\n                            discrete_state_input=model_config.discrete_state_input,\n                        ),\n                        _transforms.PadStatesAndActions(model_config.action_dim),\n                    ],\n                )\n            case _model.ModelType.PI0_FAST:\n                tokenizer_cls = (\n                    _tokenizer.FASTTokenizer\n                    if model_config.fast_model_tokenizer is None\n                    else model_config.fast_model_tokenizer\n                )\n                tokenizer_kwargs = (\n                    {} if model_config.fast_model_tokenizer_kwargs is None else model_config.fast_model_tokenizer_kwargs\n                )\n                return _transforms.Group(\n                    inputs=[\n                        _transforms.InjectDefaultPrompt(self.default_prompt),\n                        _transforms.ResizeImages(224, 224),\n                        _transforms.TokenizeFASTInputs(\n                            tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),\n                        ),\n                    ],\n                    outputs=[\n                        _transforms.ExtractFASTActions(\n                            tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),\n                            action_horizon=model_config.action_horizon,\n                            action_dim=model_config.action_dim,\n                        )\n                    ],\n                )\n\n\n@dataclasses.dataclass(frozen=True)\nclass DataConfigFactory(abc.ABC):\n    # The LeRobot repo id.\n    repo_id: str = tyro.MISSING\n    # Determines how the assets will be loaded.\n    assets: AssetsConfig = 
dataclasses.field(default_factory=AssetsConfig)\n    # Base config that will be updated by the factory.\n    base_config: tyro.conf.Suppress[DataConfig | None] = None\n\n    @abc.abstractmethod\n    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:\n        \"\"\"Create a data config.\"\"\"\n\n    def create_base_config(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:\n        repo_id = self.repo_id if self.repo_id is not tyro.MISSING else None\n        asset_id = self.assets.asset_id or repo_id\n        return dataclasses.replace(\n            self.base_config or DataConfig(),\n            repo_id=repo_id,\n            asset_id=asset_id,\n            norm_stats=self._load_norm_stats(epath.Path(self.assets.assets_dir or assets_dirs), asset_id),\n            use_quantile_norm=model_config.model_type != ModelType.PI0,\n        )\n\n    def _load_norm_stats(self, assets_dir: epath.Path, asset_id: str | None) -> dict[str, _transforms.NormStats] | None:\n        if asset_id is None:\n            return None\n        try:\n            data_assets_dir = str(assets_dir / asset_id)\n            norm_stats = _normalize.load(_download.maybe_download(data_assets_dir))\n            logging.info(f\"Loaded norm stats from {data_assets_dir}\")\n            return norm_stats\n        except FileNotFoundError:\n            logging.info(f\"Norm stats not found in {data_assets_dir}, skipping.\")\n        return None\n\n\n@dataclasses.dataclass(frozen=True)\nclass FakeDataConfig(DataConfigFactory):\n    repo_id: str = \"fake\"\n\n    @override\n    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:\n        return DataConfig(repo_id=self.repo_id)\n\n\n@dataclasses.dataclass(frozen=True)\nclass SimpleDataConfig(DataConfigFactory):\n    # Factory for the data transforms.\n    data_transforms: tyro.conf.Suppress[GroupFactory] = 
dataclasses.field(default_factory=GroupFactory)\n    # Factory for the model transforms.\n    model_transforms: tyro.conf.Suppress[GroupFactory] = dataclasses.field(default_factory=ModelTransformFactory)\n\n    @override\n    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:\n        return dataclasses.replace(\n            self.create_base_config(assets_dirs, model_config),\n            data_transforms=self.data_transforms(model_config),\n            model_transforms=self.model_transforms(model_config),\n        )\n\n\n@dataclasses.dataclass(frozen=True)\nclass LeRobotAlohaDataConfig(DataConfigFactory):\n    # If true, will convert joint dimensions to deltas with respect to the current state before passing to the model.\n    # Gripper dimensions will remain in absolute values.\n    use_delta_joint_actions: bool = True\n    # If provided, will be injected into the input data if the \"prompt\" key is not present.\n    default_prompt: str | None = None\n    # If true, this will convert the joint and gripper values from the standard Aloha space to\n    # the space used by the pi internal runtime which was used to train the base model. 
People who\n    # use standard Aloha data should set this to true.\n    adapt_to_pi: bool = True\n\n    # Repack transforms.\n    repack_transforms: tyro.conf.Suppress[_transforms.Group] = dataclasses.field(\n        default=_transforms.Group(\n            inputs=[\n                _transforms.RepackTransform(\n                    {\n                        \"images\": {\"cam_high\": \"observation.images.top\"},\n                        \"state\": \"observation.state\",\n                        \"actions\": \"action\",\n                    }\n                )\n            ]\n        )\n    )\n    # Action keys that will be used to read the action sequence from the dataset.\n    action_sequence_keys: Sequence[str] = (\"action\",)\n\n    @override\n    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:\n        data_transforms = _transforms.Group(\n            inputs=[aloha_policy.AlohaInputs(adapt_to_pi=self.adapt_to_pi)],\n            outputs=[aloha_policy.AlohaOutputs(adapt_to_pi=self.adapt_to_pi)],\n        )\n        if self.use_delta_joint_actions:\n            delta_action_mask = _transforms.make_bool_mask(6, -1, 6, -1)\n            data_transforms = data_transforms.push(\n                inputs=[_transforms.DeltaActions(delta_action_mask)],\n                outputs=[_transforms.AbsoluteActions(delta_action_mask)],\n            )\n\n        model_transforms = ModelTransformFactory(default_prompt=self.default_prompt)(model_config)\n\n        return dataclasses.replace(\n            self.create_base_config(assets_dirs, model_config),\n            repack_transforms=self.repack_transforms,\n            data_transforms=data_transforms,\n            model_transforms=model_transforms,\n            action_sequence_keys=self.action_sequence_keys,\n        )\n\n\n@dataclasses.dataclass(frozen=True)\nclass LeRobotLiberoDataConfig(DataConfigFactory):\n    \"\"\"\n    This config is used to configure transforms that are 
applied at various parts of the data pipeline.\n    For your own dataset, you can copy this class and modify the transforms to match your dataset based on the\n    comments below.\n    \"\"\"\n\n    extra_delta_transform: bool = False\n\n    @override\n    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:\n        # The repack transform is *only* applied to the data coming from the dataset,\n        # and *not* during inference. We can use it to make inputs from the dataset look\n        # as close as possible to those coming from the inference environment (e.g. match the keys).\n        # Below, we match the keys in the dataset (which we defined in the data conversion script) to\n        # the keys we use in our inference pipeline (defined in the inference script for libero).\n        # For your own dataset, first figure out what keys your environment passes to the policy server\n        # and then modify the mappings below so your dataset's keys get matched to those target keys.\n        # The repack transform simply remaps key names here.\n        repack_transform = _transforms.Group(\n            inputs=[\n                _transforms.RepackTransform(\n                    {\n                        \"observation/image\": \"image\",\n                        \"observation/wrist_image\": \"wrist_image\",\n                        \"observation/state\": \"state\",\n                        \"actions\": \"actions\",\n                        \"prompt\": \"prompt\",\n                    }\n                )\n            ]\n        )\n\n        # The data transforms are applied to the data coming from the dataset *and* during inference.\n        # Below, we define the transforms for data going into the model (``inputs``) and the transforms\n        # for data coming out of the model (``outputs``) (the latter is only used during inference).\n        # We defined these transforms in `libero_policy.py`. 
You can check the detailed comments there for\n        # how to modify the transforms to match your dataset. Once you created your own transforms, you can\n        # replace the transforms below with your own.\n        data_transforms = _transforms.Group(\n            inputs=[libero_policy.LiberoInputs(model_type=model_config.model_type)],\n            outputs=[libero_policy.LiberoOutputs()],\n        )\n\n        # One additional data transform: pi0 models are trained on delta actions (relative to the first\n        # state in each action chunk). IF your data has ``absolute`` actions (e.g. target joint angles)\n        # you can uncomment the following line to convert the actions to delta actions. The only exception\n        # is for the gripper actions which are always absolute.\n        # In the example below, we would apply the delta conversion to the first 6 actions (joints) and\n        # leave the 7th action (gripper) unchanged, i.e. absolute.\n        # In Libero, the raw actions in the dataset are already delta actions, so we *do not* need to\n        # apply a separate delta conversion (that's why it's commented out). 
Choose whether to apply this\n        # transform based on whether your dataset uses ``absolute`` or ``delta`` actions out of the box.\n\n        # LIBERO already represents actions as deltas, but we have some old Pi0 checkpoints that are trained with this\n        # extra delta transform.\n        if self.extra_delta_transform:\n            delta_action_mask = _transforms.make_bool_mask(6, -1)\n            data_transforms = data_transforms.push(\n                inputs=[_transforms.DeltaActions(delta_action_mask)],\n                outputs=[_transforms.AbsoluteActions(delta_action_mask)],\n            )\n\n        # Model transforms include things like tokenizing the prompt and action targets\n        # You do not need to change anything here for your own dataset.\n        model_transforms = ModelTransformFactory()(model_config)\n\n        # We return all data transforms for training and inference. No need to change anything here.\n        return dataclasses.replace(\n            self.create_base_config(assets_dirs, model_config),\n            repack_transforms=repack_transform,\n            data_transforms=data_transforms,\n            model_transforms=model_transforms,\n        )\n\n\n@dataclasses.dataclass(frozen=True)\nclass RLDSDroidDataConfig(DataConfigFactory):\n    \"\"\"\n    Config for training on DROID, using RLDS data format (for efficient training on larger datasets).\n    \"\"\"\n\n    rlds_data_dir: str | None = None\n    action_space: droid_rlds_dataset.DroidActionSpace | None = None\n\n    # Filtering options. Can pass a path to a dictionary that maps episodes to timestep ranges\n    # to tuples denoting ranges of time steps to keep (start, end). 
Episodes are uniquely identified with\n    # f\"{recording_folderpath}--{file_path}\", both of which are present in the RLDS episode metadata.\n\n    # List of datasets to sample from: name, version, weight, and optionally filter_dict_path\n    datasets: Sequence[droid_rlds_dataset.RLDSDataset] = (\n        droid_rlds_dataset.RLDSDataset(\n            name=\"droid\",\n            version=\"1.0.1\",\n            weight=1.0,\n            filter_dict_path=\"gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json\",\n        ),\n    )\n\n    @override\n    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:\n        repack_transform = _transforms.Group(\n            inputs=[\n                _transforms.RepackTransform(\n                    {\n                        \"observation/exterior_image_1_left\": \"observation/image\",\n                        \"observation/wrist_image_left\": \"observation/wrist_image\",\n                        \"observation/joint_position\": \"observation/joint_position\",\n                        \"observation/gripper_position\": \"observation/gripper_position\",\n                        \"actions\": \"actions\",\n                        \"prompt\": \"prompt\",\n                    }\n                )\n            ]\n        )\n\n        data_transforms = _transforms.Group(\n            inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)],\n            outputs=[droid_policy.DroidOutputs()],\n        )\n\n        if self.action_space == droid_rlds_dataset.DroidActionSpace.JOINT_POSITION:\n            # Data loader returns absolute joint position actions -- convert to delta actions for training.\n            delta_action_mask = _transforms.make_bool_mask(7, -1)\n            data_transforms = data_transforms.push(\n                inputs=[_transforms.DeltaActions(delta_action_mask)],\n                outputs=[_transforms.AbsoluteActions(delta_action_mask)],\n            )\n\n     
   model_transforms = ModelTransformFactory()(model_config)\n\n        assert self.rlds_data_dir is not None, \"Need to set rlds data dir for RLDS data loader.\"\n\n        return dataclasses.replace(\n            self.create_base_config(assets_dirs, model_config),\n            repack_transforms=repack_transform,\n            data_transforms=data_transforms,\n            model_transforms=model_transforms,\n            rlds_data_dir=self.rlds_data_dir,\n            action_space=self.action_space,\n            datasets=self.datasets,\n        )\n\n\n@dataclasses.dataclass(frozen=True)\nclass LeRobotDROIDDataConfig(DataConfigFactory):\n    \"\"\"\n    Example data config for custom DROID dataset in LeRobot format.\n    To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py\n    \"\"\"\n\n    @override\n    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:\n        repack_transform = _transforms.Group(\n            inputs=[\n                _transforms.RepackTransform(\n                    {\n                        \"observation/exterior_image_1_left\": \"exterior_image_1_left\",\n                        \"observation/exterior_image_2_left\": \"exterior_image_2_left\",\n                        \"observation/wrist_image_left\": \"wrist_image_left\",\n                        \"observation/joint_position\": \"joint_position\",\n                        \"observation/gripper_position\": \"gripper_position\",\n                        \"actions\": \"actions\",\n                        \"prompt\": \"prompt\",\n                    }\n                )\n            ]\n        )\n        # We assume joint *velocity* actions, so we should *not* apply an additional delta transform.\n        data_transforms = _transforms.Group(\n            inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)],\n            outputs=[droid_policy.DroidOutputs()],\n      
  )\n        model_transforms = ModelTransformFactory()(model_config)\n\n        return dataclasses.replace(\n            self.create_base_config(assets_dirs, model_config),\n            repack_transforms=repack_transform,\n            data_transforms=data_transforms,\n            model_transforms=model_transforms,\n        )\n\n\n@dataclasses.dataclass(frozen=True)\nclass TrainConfig:\n    # Name of the config. Must be unique. Will be used to reference this config.\n    name: tyro.conf.Suppress[str]\n    # Project name.\n    project_name: str = \"openpi\"\n    # Experiment name. Will be used to name the metadata and checkpoint directories.\n    exp_name: str = tyro.MISSING\n\n    # Defines the model config. Some attributes (action_dim, action_horizon, and max_token_len) are shared by all models\n    # -- see BaseModelConfig. Specific model implementations (e.g., Pi0Config) inherit from BaseModelConfig and may\n    # define additional attributes.\n    model: _model.BaseModelConfig = dataclasses.field(default_factory=pi0_config.Pi0Config)\n\n    # A weight loader can optionally load (possibly partial) weights from disk after the model is initialized.\n    weight_loader: weight_loaders.WeightLoader = dataclasses.field(default_factory=weight_loaders.NoOpWeightLoader)\n\n    # Optional path to a PyTorch checkpoint to load weights from.\n    pytorch_weight_path: str | None = None\n\n    # Precision for PyTorch training.\n    pytorch_training_precision: Literal[\"bfloat16\", \"float32\"] = \"bfloat16\"\n\n    lr_schedule: _optimizer.LRScheduleConfig = dataclasses.field(default_factory=_optimizer.CosineDecaySchedule)\n    optimizer: _optimizer.OptimizerConfig = dataclasses.field(default_factory=_optimizer.AdamW)\n    ema_decay: float | None = 0.99\n\n    # Specifies which weights should be frozen.\n    freeze_filter: tyro.conf.Suppress[Filter] = dataclasses.field(default_factory=nnx.Nothing)\n\n    # Determines the data to be trained on.\n    data: DataConfigFactory = 
dataclasses.field(default_factory=FakeDataConfig)\n\n    # Base directory for config assets (e.g., norm stats).\n    assets_base_dir: str = \"./assets\"\n    # Base directory for checkpoints.\n    checkpoint_base_dir: str = \"./checkpoints\"\n\n    # Random seed that will be used by random generators during training.\n    seed: int = 42\n    # Global batch size.\n    batch_size: int = 32\n    # Number of workers to use for the data loader. Increasing this number will speed up data loading but\n    # will increase memory and CPU usage.\n    num_workers: int = 2\n    # Number of train steps (batches) to run.\n    num_train_steps: int = 30_000\n\n    # How often (in steps) to log training metrics.\n    log_interval: int = 100\n    # How often (in steps) to save checkpoints.\n    save_interval: int = 1000\n    # If set, any existing checkpoints matching step % keep_period == 0 will not be deleted.\n    keep_period: int | None = 5000\n\n    # If true, will overwrite the checkpoint directory if it already exists.\n    overwrite: bool = False\n    # If true, will resume training from the last checkpoint.\n    resume: bool = False\n\n    # If true, will enable wandb logging.\n    wandb_enabled: bool = True\n\n    # Used to pass metadata to the policy server.\n    policy_metadata: dict[str, Any] | None = None\n\n    # If the value is greater than 1, FSDP will be enabled and shard across number of specified devices; overall\n    # device memory will be reduced but training could potentially be slower.\n    # eg. 
if total device is 4 and fsdp devices is 2; then the model will shard to 2 devices and run\n    # data parallel between 2 groups of devices.\n    fsdp_devices: int = 1\n\n    @property\n    def assets_dirs(self) -> pathlib.Path:\n        \"\"\"Get the assets directory for this config.\"\"\"\n        return (pathlib.Path(self.assets_base_dir) / self.name).resolve()\n\n    @property\n    def checkpoint_dir(self) -> pathlib.Path:\n        \"\"\"Get the checkpoint directory for this config.\"\"\"\n        if not self.exp_name:\n            raise ValueError(\"--exp_name must be set\")\n        return (pathlib.Path(self.checkpoint_base_dir) / self.name / self.exp_name).resolve()\n\n    @property\n    def trainable_filter(self) -> nnx.filterlib.Filter:\n        \"\"\"Get the filter for the trainable parameters.\"\"\"\n        return nnx.All(nnx.Param, nnx.Not(self.freeze_filter))\n\n    def __post_init__(self) -> None:\n        if self.resume and self.overwrite:\n            raise ValueError(\"Cannot resume and overwrite at the same time.\")\n\n\n# Use `get_config` if you need to get a config by name in your code.\n_CONFIGS = [\n    #\n    # Inference Aloha configs.\n    #\n    TrainConfig(\n        name=\"pi0_aloha\",\n        model=pi0_config.Pi0Config(),\n        data=LeRobotAlohaDataConfig(\n            assets=AssetsConfig(asset_id=\"trossen\"),\n        ),\n        policy_metadata={\"reset_pose\": [0, -1.5, 1.5, 0, 0, 0]},\n    ),\n    TrainConfig(\n        name=\"pi05_aloha\",\n        model=pi0_config.Pi0Config(pi05=True),\n        data=LeRobotAlohaDataConfig(\n            assets=AssetsConfig(asset_id=\"trossen\"),\n        ),\n        policy_metadata={\"reset_pose\": [0, -1.5, 1.5, 0, 0, 0]},\n    ),\n    TrainConfig(\n        name=\"pi0_aloha_towel\",\n        model=pi0_config.Pi0Config(),\n        data=LeRobotAlohaDataConfig(\n            assets=AssetsConfig(asset_id=\"trossen\"),\n            default_prompt=\"fold the towel\",\n        ),\n        
policy_metadata={\"reset_pose\": [0, -1.5, 1.5, 0, 0, 0]},\n    ),\n    TrainConfig(\n        name=\"pi0_aloha_tupperware\",\n        model=pi0_config.Pi0Config(),\n        data=LeRobotAlohaDataConfig(\n            assets=AssetsConfig(asset_id=\"trossen\"),\n            default_prompt=\"open the tupperware and put the food on the plate\",\n        ),\n        policy_metadata={\"reset_pose\": [0, -1.5, 1.5, 0, 0, 0]},\n    ),\n    #\n    # Inference DROID configs.\n    #\n    TrainConfig(\n        name=\"pi0_droid\",\n        model=pi0_config.Pi0Config(action_horizon=10),\n        data=SimpleDataConfig(\n            assets=AssetsConfig(asset_id=\"droid\"),\n            data_transforms=lambda model: _transforms.Group(\n                inputs=[droid_policy.DroidInputs(model_type=ModelType.PI0)],\n                outputs=[droid_policy.DroidOutputs()],\n            ),\n            base_config=DataConfig(\n                prompt_from_task=True,\n            ),\n        ),\n    ),\n    TrainConfig(\n        name=\"pi0_fast_droid\",\n        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=10),\n        data=SimpleDataConfig(\n            assets=AssetsConfig(asset_id=\"droid\"),\n            data_transforms=lambda model: _transforms.Group(\n                inputs=[droid_policy.DroidInputs(model_type=ModelType.PI0_FAST)],\n                outputs=[droid_policy.DroidOutputs()],\n            ),\n            base_config=DataConfig(\n                prompt_from_task=True,\n            ),\n        ),\n    ),\n    TrainConfig(\n        name=\"pi05_droid\",\n        model=pi0_config.Pi0Config(action_horizon=15, pi05=True),\n        data=SimpleDataConfig(\n            assets=AssetsConfig(asset_id=\"droid\"),\n            data_transforms=lambda model: _transforms.Group(\n                inputs=[droid_policy.DroidInputs(model_type=ModelType.PI05)],\n                outputs=[droid_policy.DroidOutputs()],\n            ),\n            base_config=DataConfig(\n                
prompt_from_task=True,\n            ),\n        ),\n    ),\n    #\n    # Fine-tuning Libero configs.\n    #\n    # These train configs define the hyperparameters for fine-tuning the base model on your own dataset.\n    # They are used to define key elements like the dataset you are training on, the base checkpoint you\n    # are using, and other hyperparameters like how many training steps to run or what learning rate to use.\n    # For your own dataset, you can copy this class and modify the dataset name, and data transforms based on\n    # the comments below.\n    TrainConfig(\n        # Change the name to reflect your model and dataset.\n        name=\"pi0_libero\",\n        # Here you define the model config -- In this example we use pi0 as the model\n        # architecture and perform *full* finetuning. in the examples below we show how to modify\n        # this to perform *low-memory* (LORA) finetuning and use pi0-FAST as an alternative architecture.\n        model=pi0_config.Pi0Config(),\n        # Here you define the dataset you are training on. In this example we use the Libero\n        # dataset. For your own dataset, you can change the repo_id to point to your dataset.\n        # Also modify the DataConfig to use the new config you made for your dataset above.\n        data=LeRobotLiberoDataConfig(\n            repo_id=\"physical-intelligence/libero\",\n            base_config=DataConfig(\n                # This flag determines whether we load the prompt (i.e. the task instruction) from the\n                # ``task`` field in the LeRobot dataset. If set to True, the prompt will show up in\n                # a field called ``prompt`` in the input dict. 
The recommended setting is True.\n                prompt_from_task=True,\n            ),\n            extra_delta_transform=True,\n        ),\n        # Here you define which pre-trained checkpoint you want to load to initialize the model.\n        # This should match the model config you chose above -- i.e. in this case we use the pi0 base model.\n        weight_loader=weight_loaders.CheckpointWeightLoader(\"gs://openpi-assets/checkpoints/pi0_base/params\"),\n        # Below you can define other hyperparameters like the learning rate, number of training steps, etc.\n        # Check the base TrainConfig class for a full list of available hyperparameters.\n        num_train_steps=30_000,\n    ),\n    TrainConfig(\n        name=\"pi0_libero_low_mem_finetune\",\n        # Here is an example of loading a pi0 model for LoRA fine-tuning.\n        model=pi0_config.Pi0Config(paligemma_variant=\"gemma_2b_lora\", action_expert_variant=\"gemma_300m_lora\"),\n        data=LeRobotLiberoDataConfig(\n            repo_id=\"physical-intelligence/libero\",\n            base_config=DataConfig(prompt_from_task=True),\n            extra_delta_transform=True,\n        ),\n        weight_loader=weight_loaders.CheckpointWeightLoader(\"gs://openpi-assets/checkpoints/pi0_base/params\"),\n        num_train_steps=30_000,\n        # The freeze filter defines which parameters should be frozen during training.\n        # We have a convenience function in the model config that returns the default freeze filter\n        # for the given model config for LoRA finetuning. 
Just make sure it matches the model config\n        # you chose above.\n        freeze_filter=pi0_config.Pi0Config(\n            paligemma_variant=\"gemma_2b_lora\", action_expert_variant=\"gemma_300m_lora\"\n        ).get_freeze_filter(),\n        # Turn off EMA for LoRA finetuning.\n        ema_decay=None,\n    ),\n    TrainConfig(\n        name=\"pi0_fast_libero\",\n        # Here is an example of loading a pi0-FAST model for full finetuning.\n        # Modify action_dim and action_horizon to match your dataset (action horizon is equal to\n        # the desired action chunk length).\n        # The max_token_len is the maximum number of (non-image) tokens the model can handle.\n        # This includes the tokenized prompt, proprioceptive state, and (FAST-tokenized) action tokens.\n        # Choosing this value too small may chop off tokens at the end of your sequence (the code will throw\n        # a warning), while choosing it too large will waste memory (since we pad each batch element to the\n        # max_token_len). A good rule of thumb is to use approx 180 for single-arm robots, and approx 250 for\n        # two-arm robots. 
Generally, err on the lower side here first, and potentially increase the value if\n        # you see many warnings being thrown during training.\n        model=pi0_fast.Pi0FASTConfig(action_dim=7, action_horizon=10, max_token_len=180),\n        data=LeRobotLiberoDataConfig(\n            repo_id=\"physical-intelligence/libero\",\n            base_config=DataConfig(prompt_from_task=True),\n            extra_delta_transform=True,\n        ),\n        # Note that we load the pi0-FAST base model checkpoint here.\n        weight_loader=weight_loaders.CheckpointWeightLoader(\"gs://openpi-assets/checkpoints/pi0_fast_base/params\"),\n        num_train_steps=30_000,\n    ),\n    TrainConfig(\n        name=\"pi0_fast_libero_low_mem_finetune\",\n        # Here is an example of loading a pi0-FAST model for LoRA finetuning.\n        # For setting action_dim, action_horizon, and max_token_len, see the comments above.\n        model=pi0_fast.Pi0FASTConfig(\n            action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant=\"gemma_2b_lora\"\n        ),\n        data=LeRobotLiberoDataConfig(\n            repo_id=\"physical-intelligence/libero\",\n            base_config=DataConfig(prompt_from_task=True),\n            extra_delta_transform=True,\n        ),\n        weight_loader=weight_loaders.CheckpointWeightLoader(\"gs://openpi-assets/checkpoints/pi0_fast_base/params\"),\n        num_train_steps=30_000,\n        # Again, make sure to match the model config above when extracting the freeze filter\n        # that specifies which parameters should be frozen during LoRA finetuning.\n        freeze_filter=pi0_fast.Pi0FASTConfig(\n            action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant=\"gemma_2b_lora\"\n        ).get_freeze_filter(),\n        # Turn off EMA for LoRA finetuning.\n        ema_decay=None,\n    ),\n    TrainConfig(\n        name=\"pi05_libero\",\n        model=pi0_config.Pi0Config(pi05=True, action_horizon=10, 
discrete_state_input=False),\n        data=LeRobotLiberoDataConfig(\n            repo_id=\"physical-intelligence/libero\",\n            base_config=DataConfig(prompt_from_task=True),\n            extra_delta_transform=False,\n        ),\n        batch_size=256,\n        lr_schedule=_optimizer.CosineDecaySchedule(\n            warmup_steps=10_000,\n            peak_lr=5e-5,\n            decay_steps=1_000_000,\n            decay_lr=5e-5,\n        ),\n        optimizer=_optimizer.AdamW(clip_gradient_norm=1.0),\n        ema_decay=0.999,\n        weight_loader=weight_loaders.CheckpointWeightLoader(\"gs://openpi-assets/checkpoints/pi05_base/params\"),\n        pytorch_weight_path=\"/path/to/your/pytorch_weight_path\",\n        num_train_steps=30_000,\n    ),\n    #\n    # Fine-tuning Aloha configs.\n    #\n    # This is a test config that is used to illustate how train on a custom LeRobot dataset.\n    # For instructions on how to convert and train on your own Aloha dataset see examples/aloha_real/README.md\n    TrainConfig(\n        name=\"pi0_aloha_pen_uncap\",\n        model=pi0_config.Pi0Config(),\n        data=LeRobotAlohaDataConfig(\n            repo_id=\"physical-intelligence/aloha_pen_uncap_diverse\",\n            assets=AssetsConfig(\n                assets_dir=\"gs://openpi-assets/checkpoints/pi0_base/assets\",\n                asset_id=\"trossen\",\n            ),\n            default_prompt=\"uncap the pen\",\n            repack_transforms=_transforms.Group(\n                inputs=[\n                    _transforms.RepackTransform(\n                        {\n                            \"images\": {\n                                \"cam_high\": \"observation.images.cam_high\",\n                                \"cam_left_wrist\": \"observation.images.cam_left_wrist\",\n                                \"cam_right_wrist\": \"observation.images.cam_right_wrist\",\n                            },\n                            \"state\": \"observation.state\",\n   
                         \"actions\": \"action\",\n                        }\n                    )\n                ]\n            ),\n        ),\n        weight_loader=weight_loaders.CheckpointWeightLoader(\"gs://openpi-assets/checkpoints/pi0_base/params\"),\n        num_train_steps=20_000,\n    ),\n    TrainConfig(\n        name=\"pi05_aloha_pen_uncap\",\n        model=pi0_config.Pi0Config(pi05=True),\n        data=LeRobotAlohaDataConfig(\n            repo_id=\"physical-intelligence/aloha_pen_uncap_diverse\",\n            assets=AssetsConfig(\n                assets_dir=\"gs://openpi-assets/checkpoints/pi05_base/assets\",\n                asset_id=\"trossen\",\n            ),\n            default_prompt=\"uncap the pen\",\n            repack_transforms=_transforms.Group(\n                inputs=[\n                    _transforms.RepackTransform(\n                        {\n                            \"images\": {\n                                \"cam_high\": \"observation.images.cam_high\",\n                                \"cam_left_wrist\": \"observation.images.cam_left_wrist\",\n                                \"cam_right_wrist\": \"observation.images.cam_right_wrist\",\n                            },\n                            \"state\": \"observation.state\",\n                            \"actions\": \"action\",\n                        }\n                    )\n                ]\n            ),\n        ),\n        weight_loader=weight_loaders.CheckpointWeightLoader(\"gs://openpi-assets/checkpoints/pi05_base/params\"),\n        num_train_steps=20_000,\n        batch_size=64,\n    ),\n    #\n    # Fine-tuning DROID configs.\n    #\n    TrainConfig(\n        # This config is for fine-tuning pi0-FAST-base on the *full* DROID dataset.\n        # We use RLDS data loading to make training on this large dataset tractable.\n        # For fine-tuning on your own DROID dataset, see below.\n        name=\"pi0_fast_full_droid_finetune\",\n        
model=pi0_fast.Pi0FASTConfig(\n            action_dim=8,\n            action_horizon=16,\n            max_token_len=180,\n        ),\n        data=RLDSDroidDataConfig(\n            repo_id=\"droid\",\n            # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory).\n            rlds_data_dir=\"<path_to_droid_rlds_dataset>\",\n            action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,\n        ),\n        weight_loader=weight_loaders.CheckpointWeightLoader(\"gs://openpi-assets/checkpoints/pi0_fast_base/params\"),\n        lr_schedule=_optimizer.CosineDecaySchedule(\n            warmup_steps=1_000,\n            peak_lr=5e-5,\n            decay_steps=1_000_000,\n            decay_lr=5e-5,\n        ),\n        num_train_steps=100_000,  # 100k steps should be sufficient, takes ~2 days on 8x H100s\n        batch_size=256,\n        log_interval=100,\n        save_interval=5000,\n        keep_period=20_000,\n        num_workers=0,  # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally\n    ),\n    TrainConfig(\n        # This config is for fine-tuning pi05 on the *full* DROID dataset.\n        # We use RLDS data loading to make training on this large dataset tractable.\n        # For fine-tuning on your own DROID dataset, see below.\n        name=\"pi05_full_droid_finetune\",\n        model=pi0_config.Pi0Config(\n            pi05=True,\n            action_dim=32,\n            action_horizon=16,\n        ),\n        data=RLDSDroidDataConfig(\n            repo_id=\"droid\",\n            # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory).\n            rlds_data_dir=\"/mnt/pi-data/kevin\",\n            action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,\n            assets=AssetsConfig(\n                assets_dir=\"gs://openpi-assets/checkpoints/pi05_base/assets/\",\n                asset_id=\"droid\",\n            ),\n 
       ),\n        weight_loader=weight_loaders.CheckpointWeightLoader(\"gs://openpi-assets/checkpoints/pi05_base/params\"),\n        lr_schedule=_optimizer.CosineDecaySchedule(\n            warmup_steps=1_000,\n            peak_lr=5e-5,\n            decay_steps=1_000_000,\n            decay_lr=5e-5,\n        ),\n        num_train_steps=100_000,\n        batch_size=256,\n        log_interval=100,\n        save_interval=5000,\n        keep_period=10_000,\n        num_workers=0,  # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally\n    ),\n    TrainConfig(\n        # This config is for fine-tuning pi05-DROID on a custom (smaller) DROID dataset.\n        # Here, we use LeRobot data format (like for all other fine-tuning examples)\n        # To convert your custom DROID dataset (<10s of hours) to LeRobot format, see examples/droid/convert_droid_data_to_lerobot.py\n        name=\"pi05_droid_finetune\",\n        model=pi0_config.Pi0Config(\n            pi05=True,\n            action_dim=32,  # pi05 is trained with 32-dim actions\n            action_horizon=16,\n        ),\n        data=LeRobotDROIDDataConfig(\n            # Replace with your custom DROID LeRobot dataset repo id.\n            repo_id=\"your_hf_username/my_droid_dataset\",\n            base_config=DataConfig(prompt_from_task=True),\n            assets=AssetsConfig(\n                # Important: reuse the original DROID norm stats during fine-tuning!\n                assets_dir=\"gs://openpi-assets/checkpoints/pi05_droid/assets\",\n                asset_id=\"droid\",\n            ),\n        ),\n        weight_loader=weight_loaders.CheckpointWeightLoader(\"gs://openpi-assets/checkpoints/pi05_droid/params\"),\n        num_train_steps=20_000,\n        batch_size=32,\n    ),\n    #\n    # ALOHA Sim configs. 
This config is used to demonstrate how to train on a simple simulated environment.\n    #\n    TrainConfig(\n        name=\"pi0_aloha_sim\",\n        model=pi0_config.Pi0Config(),\n        data=LeRobotAlohaDataConfig(\n            repo_id=\"lerobot/aloha_sim_transfer_cube_human\",\n            default_prompt=\"Transfer cube\",\n            use_delta_joint_actions=False,\n        ),\n        weight_loader=weight_loaders.CheckpointWeightLoader(\"gs://openpi-assets/checkpoints/pi0_base/params\"),\n        num_train_steps=20_000,\n    ),\n    #\n    # Debugging configs.\n    #\n    TrainConfig(\n        name=\"debug\",\n        data=FakeDataConfig(),\n        batch_size=2,\n        model=pi0_config.Pi0Config(paligemma_variant=\"dummy\", action_expert_variant=\"dummy\"),\n        save_interval=100,\n        overwrite=True,\n        exp_name=\"debug\",\n        num_train_steps=10,\n        wandb_enabled=False,\n    ),\n    TrainConfig(\n        name=\"debug_restore\",\n        data=FakeDataConfig(),\n        batch_size=2,\n        model=pi0_config.Pi0Config(paligemma_variant=\"dummy\", action_expert_variant=\"dummy\"),\n        weight_loader=weight_loaders.CheckpointWeightLoader(\"./checkpoints/debug/debug/9/params\"),\n        overwrite=True,\n        exp_name=\"debug\",\n        num_train_steps=10,\n        wandb_enabled=False,\n    ),\n    TrainConfig(\n        name=\"debug_pi05\",\n        model=pi0_config.Pi0Config(pi05=True, paligemma_variant=\"dummy\", action_expert_variant=\"dummy\"),\n        data=FakeDataConfig(),\n        batch_size=2,\n        num_train_steps=10,\n        overwrite=True,\n        exp_name=\"debug_pi05\",\n        wandb_enabled=False,\n    ),\n    # RoboArena & PolaRiS configs.\n    *roboarena_config.get_roboarena_configs(),\n    *polaris_config.get_polaris_configs(),\n]\n\nif len({config.name for config in _CONFIGS}) != len(_CONFIGS):\n    raise ValueError(\"Config names must be unique.\")\n_CONFIGS_DICT = {config.name: config for config in 
_CONFIGS}\n\n\ndef cli() -> TrainConfig:\n    return tyro.extras.overridable_config_cli({k: (k, v) for k, v in _CONFIGS_DICT.items()})\n\n\ndef get_config(config_name: str) -> TrainConfig:\n    \"\"\"Get a config by name.\"\"\"\n    if config_name not in _CONFIGS_DICT:\n        closest = difflib.get_close_matches(config_name, _CONFIGS_DICT.keys(), n=1, cutoff=0.0)\n        closest_str = f\" Did you mean '{closest[0]}'? \" if closest else \"\"\n        raise ValueError(f\"Config '{config_name}' not found.{closest_str}\")\n\n    return _CONFIGS_DICT[config_name]\n"
  },
  {
    "path": "src/openpi/training/data_loader.py",
    "content": "from collections.abc import Iterator, Sequence\nimport logging\nimport multiprocessing\nimport os\nimport typing\nfrom typing import Literal, Protocol, SupportsIndex, TypeVar\n\nimport jax\nimport jax.numpy as jnp\nimport lerobot.common.datasets.lerobot_dataset as lerobot_dataset\nimport numpy as np\nimport torch\n\nimport openpi.models.model as _model\nimport openpi.training.config as _config\nfrom openpi.training.droid_rlds_dataset import DroidRldsDataset\nimport openpi.transforms as _transforms\n\nT_co = TypeVar(\"T_co\", covariant=True)\n\n\nclass Dataset(Protocol[T_co]):\n    \"\"\"Interface for a dataset with random access.\"\"\"\n\n    def __getitem__(self, index: SupportsIndex) -> T_co:\n        raise NotImplementedError(\"Subclasses of Dataset should implement __getitem__.\")\n\n    def __len__(self) -> int:\n        raise NotImplementedError(\"Subclasses of Dataset should implement __len__.\")\n\n\nclass IterableDataset(Protocol[T_co]):\n    \"\"\"Interface for an iterable dataset.\"\"\"\n\n    def __iter__(self) -> Iterator[T_co]:\n        raise NotImplementedError(\"Subclasses of IterableDataset should implement __iter__.\")\n\n    def __len__(self) -> int:\n        raise NotImplementedError(\"Subclasses of Dataset should implement __len__.\")\n\n\nclass DataLoader(Protocol[T_co]):\n    \"\"\"Interface for a data loader.\"\"\"\n\n    def data_config(self) -> _config.DataConfig:\n        \"\"\"Get the data config for this data loader.\"\"\"\n        raise NotImplementedError(\"Subclasses of DataLoader should implement data_config.\")\n\n    def __iter__(self) -> Iterator[T_co]:\n        raise NotImplementedError(\"Subclasses of DataLoader should implement __iter__.\")\n\n\nclass TransformedDataset(Dataset[T_co]):\n    def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]):\n        self._dataset = dataset\n        self._transform = _transforms.compose(transforms)\n\n    def __getitem__(self, index: 
SupportsIndex) -> T_co:\n        return self._transform(self._dataset[index])\n\n    def __len__(self) -> int:\n        return len(self._dataset)\n\n\nclass IterableTransformedDataset(IterableDataset[T_co]):\n    def __init__(\n        self,\n        dataset: IterableDataset,\n        transforms: Sequence[_transforms.DataTransformFn],\n        *,\n        is_batched: bool = False,\n    ):\n        self._dataset = dataset\n        self._transform = _transforms.compose(transforms)\n        self._is_batched = is_batched\n\n    def __iter__(self):\n        for sample in self._dataset:\n            if self._is_batched:\n                # Transforms are designed to be applied to individual samples. So we need to split the batch into\n                # individual samples and apply the transform to each sample individually.\n                batch_size = next(v.shape[0] for v in sample.values())\n\n                # Split batch into individual samples using tree_map\n                individual_samples = [jax.tree.map(lambda x: x[i], sample) for i in range(batch_size)]  # noqa: B023\n\n                # Transform each sample\n                transformed = [self._transform(s) for s in individual_samples]\n\n                # Recombine batch with tree_map\n                yield jax.tree.map(lambda *x: np.stack(x, axis=0), *transformed)\n            else:\n                yield self._transform(sample)\n\n    def __len__(self) -> int:\n        return len(self._dataset)\n\n\nclass FakeDataset(Dataset):\n    def __init__(self, model_config: _model.BaseModelConfig, num_samples: int):\n        self._num_samples = num_samples\n        self._observation_spec, self._action_spec = model_config.inputs_spec()\n\n    def __getitem__(self, index: SupportsIndex) -> dict:\n        rng = jax.random.key(index.__index__())\n\n        def make_from_spec(spec: jax.ShapeDtypeStruct):\n            nonlocal rng\n            rng, data_rng = jax.random.split(rng)\n            # Remove the batch 
dimension.\n            shape = spec.shape[1:]\n            if spec.dtype == jnp.float32:\n                return jax.random.uniform(data_rng, shape=shape, minval=-1.0, maxval=1.0)\n            if spec.dtype == jnp.int32:\n                return jax.random.randint(data_rng, shape=shape, minval=0, maxval=2048)\n            return jnp.zeros(shape=shape, dtype=spec.dtype)\n\n        observation = jax.tree.map(make_from_spec, self._observation_spec)\n        action = jax.tree.map(make_from_spec, self._action_spec)\n\n        return {\n            **observation.to_dict(),\n            \"actions\": action,\n        }\n\n    def __len__(self) -> int:\n        return self._num_samples\n\n\ndef create_torch_dataset(\n    data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig\n) -> Dataset:\n    \"\"\"Create a dataset for training.\"\"\"\n    repo_id = data_config.repo_id\n    if repo_id is None:\n        raise ValueError(\"Repo ID is not set. Cannot create dataset.\")\n    if repo_id == \"fake\":\n        return FakeDataset(model_config, num_samples=1024)\n\n    dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)\n    dataset = lerobot_dataset.LeRobotDataset(\n        data_config.repo_id,\n        delta_timestamps={\n            key: [t / dataset_meta.fps for t in range(action_horizon)] for key in data_config.action_sequence_keys\n        },\n    )\n\n    if data_config.prompt_from_task:\n        dataset = TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)])\n\n    return dataset\n\n\ndef create_rlds_dataset(\n    data_config: _config.DataConfig,\n    action_horizon: int,\n    batch_size: int,\n    *,\n    shuffle: bool = False,\n) -> Dataset:\n    # At the moment, we only support DROID for RLDS datasets.\n    return DroidRldsDataset(\n        data_dir=data_config.rlds_data_dir,\n        batch_size=batch_size,\n        shuffle=shuffle,\n        action_chunk_size=action_horizon,\n        
action_space=data_config.action_space,\n        datasets=data_config.datasets,\n    )\n\n\ndef transform_dataset(dataset: Dataset, data_config: _config.DataConfig, *, skip_norm_stats: bool = False) -> Dataset:\n    \"\"\"Transform the dataset by applying the data transforms.\"\"\"\n    norm_stats = {}\n    if data_config.repo_id != \"fake\" and not skip_norm_stats:\n        if data_config.norm_stats is None:\n            raise ValueError(\n                \"Normalization stats not found. \"\n                \"Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`.\"\n            )\n        norm_stats = data_config.norm_stats\n\n    return TransformedDataset(\n        dataset,\n        [\n            *data_config.repack_transforms.inputs,\n            *data_config.data_transforms.inputs,\n            _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),\n            *data_config.model_transforms.inputs,\n        ],\n    )\n\n\ndef transform_iterable_dataset(\n    dataset: IterableDataset,\n    data_config: _config.DataConfig,\n    *,\n    skip_norm_stats: bool = False,\n    is_batched: bool = False,\n) -> IterableDataset:\n    \"\"\"Transform the dataset by applying the data transforms.\"\"\"\n    norm_stats = {}\n    if data_config.repo_id != \"fake\" and not skip_norm_stats:\n        if data_config.norm_stats is None:\n            raise ValueError(\n                \"Normalization stats not found. 
\"\n                \"Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`.\"\n            )\n        norm_stats = data_config.norm_stats\n\n    return IterableTransformedDataset(\n        dataset,\n        [\n            *data_config.repack_transforms.inputs,\n            *data_config.data_transforms.inputs,\n            _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),\n            *data_config.model_transforms.inputs,\n        ],\n        is_batched=is_batched,\n    )\n\n\ndef create_data_loader(\n    config: _config.TrainConfig,\n    *,\n    sharding: jax.sharding.Sharding | None = None,\n    shuffle: bool = False,\n    num_batches: int | None = None,\n    skip_norm_stats: bool = False,\n    framework: Literal[\"jax\", \"pytorch\"] = \"jax\",\n) -> DataLoader[tuple[_model.Observation, _model.Actions]]:\n    \"\"\"Create a data loader for training.\n\n    Args:\n        config: The training configuration.\n        sharding: The sharding to use for the data loader (JAX only).\n        shuffle: Whether to shuffle the data.\n        num_batches: Determines the number of batches to return.\n        skip_norm_stats: Whether to skip data normalization.\n        framework: The framework to use (\"jax\" or \"pytorch\").\n    \"\"\"\n    data_config = config.data.create(config.assets_dirs, config.model)\n    logging.info(f\"data_config: {data_config}\")\n\n    if data_config.rlds_data_dir is not None:\n        return create_rlds_data_loader(\n            data_config,\n            action_horizon=config.model.action_horizon,\n            batch_size=config.batch_size,\n            sharding=sharding,\n            shuffle=shuffle,\n            num_batches=num_batches,\n            skip_norm_stats=skip_norm_stats,\n            framework=framework,\n        )\n    return create_torch_data_loader(\n        data_config,\n        model_config=config.model,\n        action_horizon=config.model.action_horizon,\n        
batch_size=config.batch_size,\n        sharding=sharding,\n        shuffle=shuffle,\n        num_batches=num_batches,\n        num_workers=config.num_workers,\n        seed=config.seed,\n        skip_norm_stats=skip_norm_stats,\n        framework=framework,\n    )\n\n\ndef create_torch_data_loader(\n    data_config: _config.DataConfig,\n    model_config: _model.BaseModelConfig,\n    action_horizon: int,\n    batch_size: int,\n    *,\n    sharding: jax.sharding.Sharding | None = None,\n    skip_norm_stats: bool = False,\n    shuffle: bool = False,\n    num_batches: int | None = None,\n    num_workers: int = 0,\n    seed: int = 0,\n    framework: str = \"jax\",\n) -> DataLoader[tuple[_model.Observation, _model.Actions]]:\n    \"\"\"Create a data loader for training.\n\n    Args:\n        data_config: The data configuration.\n        action_horizon: The action horizon.\n        batch_size: The batch size.\n        sharding: The sharding to use for the data loader. If None, the data loader will\n            use a single device sharding.\n        skip_norm_stats: Whether to skip data normalization.\n        shuffle: Whether to shuffle the data.\n        num_batches: Determines the number of batches to return. If the number exceeds the\n            number of batches in the dataset, the data loader will loop over the dataset.\n            If not provided, will iterate over the dataset indefinitely.\n        num_workers: The number of worker processes to use. 
If zero, the data loader will\n            execute in the main process.\n        seed: The seed to use for shuffling the data.\n    \"\"\"\n    dataset = create_torch_dataset(data_config, action_horizon, model_config)\n    dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats)\n\n    # Use TorchDataLoader for both frameworks\n    # For PyTorch DDP, create DistributedSampler and divide batch size by world size\n    # For JAX, divide by process count\n    sampler = None\n    if framework == \"pytorch\":\n        if torch.distributed.is_initialized():\n            sampler = torch.utils.data.distributed.DistributedSampler(\n                dataset,\n                num_replicas=torch.distributed.get_world_size(),\n                rank=torch.distributed.get_rank(),\n                shuffle=shuffle,\n                drop_last=True,\n            )\n            local_batch_size = batch_size // torch.distributed.get_world_size()\n        else:\n            local_batch_size = batch_size\n    else:\n        local_batch_size = batch_size // jax.process_count()\n\n    logging.info(f\"local_batch_size: {local_batch_size}\")\n    data_loader = TorchDataLoader(\n        dataset,\n        local_batch_size=local_batch_size,\n        sharding=None if framework == \"pytorch\" else sharding,\n        shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler\n        sampler=sampler,\n        num_batches=num_batches,\n        num_workers=num_workers,\n        seed=seed,\n        framework=framework,\n    )\n\n    return DataLoaderImpl(data_config, data_loader)\n\n\ndef create_rlds_data_loader(\n    data_config: _config.DataConfig,\n    action_horizon: int,\n    batch_size: int,\n    *,\n    sharding: jax.sharding.Sharding | None = None,\n    skip_norm_stats: bool = False,\n    shuffle: bool = False,\n    num_batches: int | None = None,\n    framework: str = \"jax\",\n) -> DataLoader[tuple[_model.Observation, _model.Actions]]:\n    \"\"\"Create 
an RLDS data loader for training.\n\n    Note: This data loader requires some extra dependencies -- see examples/droid/README_train.md\n\n    Args:\n        data_config: The data configuration.\n        action_horizon: The action horizon.\n        batch_size: The batch size.\n        sharding: The sharding to use for the data loader. If None, the data loader will\n            use a single device sharding.\n        skip_norm_stats: Whether to skip data normalization.\n        shuffle: Whether to shuffle the data.\n        num_batches: Determines the number of batches to return. If the number exceeds the\n            number of batches in the dataset, the data loader will loop over the dataset.\n            If not provided, will iterate over the dataset indefinitely.\n    \"\"\"\n    if framework == \"pytorch\":\n        raise NotImplementedError(\"PyTorch RLDS data loader is not supported yet\")\n    dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)\n    dataset = transform_iterable_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)\n\n    data_loader = RLDSDataLoader(\n        dataset,\n        sharding=sharding,\n        num_batches=num_batches,\n    )\n\n    return DataLoaderImpl(data_config, data_loader)\n\n\nclass TorchDataLoader:\n    \"\"\"Torch data loader implementation.\"\"\"\n\n    def __init__(\n        self,\n        dataset,\n        local_batch_size: int,\n        *,\n        sharding: jax.sharding.Sharding | None = None,\n        shuffle: bool = False,\n        sampler: torch.utils.data.Sampler | None = None,\n        num_batches: int | None = None,\n        num_workers: int = 0,\n        seed: int = 0,\n        framework: str = \"jax\",\n    ):\n        \"\"\"Create a PyTorch data loader.\n\n        Args:\n            dataset: The dataset to load.\n            local_batch_size: The local batch size for each process.\n            sharding: The sharding to use for the data loader.\n     
       shuffle: Whether to shuffle the data.\n            num_batches: If provided, determines the number of returned batches. If the\n                number is larger than the number of batches in the dataset, the data loader\n                will loop over the dataset. If not provided, will iterate over the dataset\n                indefinitely.\n            num_workers: The number of worker processes to use. If zero, the data loader will\n                execute in the main process.\n            seed: The seed to use for shuffling the data.\n        \"\"\"\n        if jax.process_count() > 1:\n            raise NotImplementedError(\"Data loading with multiple processes is not supported.\")\n\n        if len(dataset) < local_batch_size:\n            raise ValueError(f\"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).\")\n\n        # Store sharding - None for PyTorch, JAX sharding for JAX\n        self._sharding = sharding\n        if sharding is None and framework == \"jax\":\n            # Use data parallel sharding by default for JAX only.\n            self._sharding = jax.sharding.NamedSharding(\n                jax.sharding.Mesh(jax.devices(), (\"B\",)),\n                jax.sharding.PartitionSpec(\"B\"),\n            )\n        self._num_batches = num_batches\n\n        mp_context = None\n        if num_workers > 0:\n            mp_context = multiprocessing.get_context(\"spawn\")\n\n        generator = torch.Generator()\n        generator.manual_seed(seed)\n        self._data_loader = torch.utils.data.DataLoader(\n            typing.cast(torch.utils.data.Dataset, dataset),\n            batch_size=local_batch_size,\n            shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler\n            sampler=sampler,\n            num_workers=num_workers,\n            multiprocessing_context=mp_context,\n            persistent_workers=num_workers > 0,\n            collate_fn=_collate_fn,\n            
worker_init_fn=_worker_init_fn,\n            drop_last=True,\n            generator=generator,\n        )\n\n    @property\n    def torch_loader(self) -> torch.utils.data.DataLoader:\n        return self._data_loader\n\n    def __iter__(self):\n        num_items = 0\n        while True:\n            data_iter = iter(self._data_loader)\n            while True:\n                if self._num_batches is not None and num_items >= self._num_batches:\n                    return\n                try:\n                    batch = next(data_iter)\n                except StopIteration:\n                    break  # We've exhausted the dataset. Create a new iterator and start over.\n                num_items += 1\n                # For JAX, convert to sharded arrays; for PyTorch, return torch tensors\n                if self._sharding is not None:\n                    yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)\n                else:\n                    yield jax.tree.map(torch.as_tensor, batch)\n\n\ndef _collate_fn(items):\n    \"\"\"Collate the batch elements into batched numpy arrays.\"\"\"\n    # Make sure to convert to numpy arrays before stacking since some of the incoming elements\n    # may be JAX arrays.\n    return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items)\n\n\ndef _worker_init_fn(worker_id: int) -> None:\n    \"\"\"Tell JAX inside the worker process not to preallocate the GPU memory.\"\"\"\n    # NOTE: This is called after jax is imported inside the worker process. 
This\n    # means that this approach will not work for selecting the backend.\n    os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n    os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"] = \"platform\"\n\n\nclass RLDSDataLoader:\n    \"\"\"Shallow wrapper around the DROID data loader to make it compatible with openpi.\n\n    All batching already happens in the DROID dataset, so we don't need to do anything here.\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset: DroidRldsDataset,\n        *,\n        sharding: jax.sharding.Sharding | None = None,\n        num_batches: int | None = None,\n    ):\n        self._dataset = dataset\n        self._num_batches = num_batches\n\n        if jax.process_count() > 1:\n            raise NotImplementedError(\"Data loading with multiple processes is not supported.\")\n\n        if sharding is None:\n            # Use data parallel sharding by default.\n            sharding = jax.sharding.NamedSharding(\n                jax.sharding.Mesh(jax.devices(), (\"B\",)),\n                jax.sharding.PartitionSpec(\"B\"),\n            )\n\n        self._sharding = sharding\n        self._num_batches = num_batches\n\n    def __iter__(self):\n        num_items = 0\n        while True:\n            data_iter = iter(self._dataset)\n            while True:\n                if self._num_batches is not None and num_items >= self._num_batches:\n                    return\n                try:\n                    batch = next(data_iter)\n                except StopIteration:\n                    break  # We've exhausted the dataset. 
Create a new iterator and start over.\n                num_items += 1\n                yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)\n\n\nclass DataLoaderImpl(DataLoader):\n    def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader):\n        self._data_config = data_config\n        self._data_loader = data_loader\n\n    def data_config(self) -> _config.DataConfig:\n        return self._data_config\n\n    def __iter__(self):\n        for batch in self._data_loader:\n            yield _model.Observation.from_dict(batch), batch[\"actions\"]\n"
  },
  {
    "path": "src/openpi/training/data_loader_test.py",
    "content": "import dataclasses\n\nimport jax\n\nfrom openpi.models import pi0_config\nfrom openpi.training import config as _config\nfrom openpi.training import data_loader as _data_loader\n\n\ndef test_torch_data_loader():\n    config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)\n    dataset = _data_loader.FakeDataset(config, 16)\n\n    loader = _data_loader.TorchDataLoader(\n        dataset,\n        local_batch_size=4,\n        num_batches=2,\n    )\n    batches = list(loader)\n\n    assert len(batches) == 2\n    for batch in batches:\n        assert all(x.shape[0] == 4 for x in jax.tree.leaves(batch))\n\n\ndef test_torch_data_loader_infinite():\n    config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)\n    dataset = _data_loader.FakeDataset(config, 4)\n\n    loader = _data_loader.TorchDataLoader(dataset, local_batch_size=4)\n    data_iter = iter(loader)\n\n    for _ in range(10):\n        _ = next(data_iter)\n\n\ndef test_torch_data_loader_parallel():\n    config = pi0_config.Pi0Config(action_dim=24, action_horizon=50, max_token_len=48)\n    dataset = _data_loader.FakeDataset(config, 10)\n\n    loader = _data_loader.TorchDataLoader(dataset, local_batch_size=4, num_batches=2, num_workers=2)\n    batches = list(loader)\n\n    assert len(batches) == 2\n\n    for batch in batches:\n        assert all(x.shape[0] == 4 for x in jax.tree.leaves(batch))\n\n\ndef test_with_fake_dataset():\n    config = _config.get_config(\"debug\")\n\n    loader = _data_loader.create_data_loader(config, skip_norm_stats=True, num_batches=2)\n    batches = list(loader)\n\n    assert len(batches) == 2\n\n    for batch in batches:\n        assert all(x.shape[0] == config.batch_size for x in jax.tree.leaves(batch))\n\n    for _, actions in batches:\n        assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim)\n\n\ndef test_with_real_dataset():\n    config = 
_config.get_config(\"pi0_aloha_sim\")\n    config = dataclasses.replace(config, batch_size=4)\n\n    loader = _data_loader.create_data_loader(\n        config,\n        # Skip since we may not have the data available.\n        skip_norm_stats=True,\n        num_batches=2,\n        shuffle=True,\n    )\n    # Make sure that we can get the data config.\n    assert loader.data_config().repo_id == config.data.repo_id\n\n    batches = list(loader)\n\n    assert len(batches) == 2\n\n    for _, actions in batches:\n        assert actions.shape == (config.batch_size, config.model.action_horizon, config.model.action_dim)\n"
  },
  {
    "path": "src/openpi/training/droid_rlds_dataset.py",
    "content": "\"\"\"\nRLDS-based data loader for DROID.\nWhile openpi typically uses LeRobot's data loader, it is not currently scalable enough for larger datasets like DROID.\nThus, we provide a data loader example here that uses the RLDS data format.\nThe data loader also applies a few DROID-specific data filters / transformations.\n\"\"\"\n\nfrom collections.abc import Sequence\nimport dataclasses\nfrom enum import Enum\nfrom enum import auto\nimport json\nimport logging\nfrom pathlib import Path\n\nimport tqdm\n\nimport openpi.shared.download as download\n\n\nclass DroidActionSpace(Enum):\n    \"\"\"Action space for DROID dataset.\"\"\"\n\n    JOINT_POSITION = auto()\n    JOINT_VELOCITY = auto()\n\n\n@dataclasses.dataclass\nclass RLDSDataset:\n    name: str\n    version: str\n    weight: float\n    filter_dict_path: str | None = None\n\n\nclass DroidRldsDataset:\n    def __init__(\n        self,\n        data_dir: str,\n        batch_size: int,\n        datasets: Sequence[RLDSDataset],\n        *,  # Force keyword-only arguments\n        shuffle: bool = True,\n        action_chunk_size: int = 16,\n        # We default to joint position actions, since they allow policy evaluation in simulation.\n        action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION,\n        max_loaded_steps_per_episode: int = 100,\n        # Reduce this if you are running out of memory, but careful -- below ~100k shuffling is not sufficiently random.\n        shuffle_buffer_size: int = 250_000,\n        num_parallel_reads: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level\n        num_parallel_calls: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level\n    ):\n        # Import tensorflow here to not make it mandatory in case RLDS data loader is not used.\n        import dlimp as dl\n        import tensorflow as tf\n        import tensorflow_datasets as tfds\n\n        # Configure Tensorflow with *no GPU devices* (to prevent 
clobber with PyTorch / JAX)\n        tf.config.set_visible_devices([], \"GPU\")\n\n        # Ensure dataset weights sum to 1.0\n        assert sum(dataset.weight for dataset in datasets) == 1.0, \"Dataset weights must sum to 1.0\"\n\n        def prepare_single_dataset(dataset_cfg: RLDSDataset):\n            # ds_name, version = dataset_name.split(\":\")\n            ds_name, version = dataset_cfg.name, dataset_cfg.version\n            builder = tfds.builder(ds_name, data_dir=data_dir, version=version)\n            dataset = dl.DLataset.from_rlds(\n                builder, split=\"train\", shuffle=shuffle, num_parallel_reads=num_parallel_reads\n            )\n\n            # Filter out any unsuccessful trajectories -- we use the file name to check this\n            dataset = dataset.filter(\n                lambda traj: tf.strings.regex_full_match(\n                    traj[\"traj_metadata\"][\"episode_metadata\"][\"file_path\"][0], \".*success.*\"\n                )\n            )\n\n            # Repeat dataset so we never run out of data.\n            dataset = dataset.repeat()\n\n            # Load the filter dictionary if provided.\n            # The filter dictionary is a JSON file that maps episode keys to ranges of frames to sample\n            # (e.g.,\n            # {\n            #     \"<episode key>\": [[0, 100], [200, 300]]\n            # }\n            # means keep frames 0-99 and 200-299).\n\n            filter_dict_path = dataset_cfg.filter_dict_path\n            if filter_dict_path is not None:\n                cached_filter_dict_path = download.maybe_download(filter_dict_path)\n                with Path(cached_filter_dict_path).open(\"r\") as f:\n                    filter_dict = json.load(f)\n                logging.info(f\"Using filter dictionary with {len(filter_dict)} episodes\")\n\n                keys_tensor = []\n                values_tensor = []\n\n                for episode_key, ranges in tqdm.tqdm(filter_dict.items(), desc=\"Creating 
idle filter hash table...\"):\n                    for start, end in ranges:\n                        for t in range(start, end):\n                            frame_key = f\"{episode_key}--{t}\"\n                            keys_tensor.append(frame_key)\n                            values_tensor.append(True)\n                self.filter_table = tf.lookup.StaticHashTable(\n                    tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor), default_value=False\n                )\n                logging.info(\"Filter hash table initialized\")\n            else:\n                self.filter_table = tf.lookup.StaticHashTable(\n                    tf.lookup.KeyValueTensorInitializer([\"\"], [True]), default_value=True\n                )\n\n            def restructure(traj):\n                \"\"\"Reformat observation and action keys, sample language instruction.\"\"\"\n                # Important: we use joint *position* action space -- easier to simulate!\n                actions = tf.concat(\n                    (\n                        (\n                            traj[\"action_dict\"][\"joint_position\"]\n                            if action_space == DroidActionSpace.JOINT_POSITION\n                            else traj[\"action_dict\"][\"joint_velocity\"]\n                        ),\n                        traj[\"action_dict\"][\"gripper_position\"],\n                    ),\n                    axis=-1,\n                )\n                # Randomly samples one of the two exterior images in DROID during training (we only train with one at a time).\n                # Note: the \"left\" refers to the left camera in the stereo pair, we only train on the left camera.\n                exterior_img = tf.cond(\n                    tf.random.uniform(shape=[]) > 0.5,\n                    lambda: traj[\"observation\"][\"exterior_image_1_left\"],\n                    lambda: traj[\"observation\"][\"exterior_image_2_left\"],\n                )\n          
      wrist_img = traj[\"observation\"][\"wrist_image_left\"]\n                # Randomly sample one of the three language instructions\n                instruction = tf.random.shuffle(\n                    [traj[\"language_instruction\"], traj[\"language_instruction_2\"], traj[\"language_instruction_3\"]]\n                )[0]\n\n                traj_len = tf.shape(traj[\"action\"])[0]\n                indices = tf.as_string(tf.range(traj_len))\n\n                # Data filtering:\n                # Compute a uniquely-identifying step ID by concatenating the recording folderpath, file path,\n                # and each step's time step index. This will index into the filter hash table, and if it returns true,\n                # then the frame passes the filter.\n                step_id = (\n                    traj[\"traj_metadata\"][\"episode_metadata\"][\"recording_folderpath\"]\n                    + \"--\"\n                    + traj[\"traj_metadata\"][\"episode_metadata\"][\"file_path\"]\n                    + \"--\"\n                    + indices\n                )\n                passes_filter = self.filter_table.lookup(step_id)\n\n                return {\n                    \"actions\": actions,\n                    \"observation\": {\n                        \"image\": exterior_img,\n                        \"wrist_image\": wrist_img,\n                        \"joint_position\": traj[\"observation\"][\"joint_position\"],\n                        \"gripper_position\": traj[\"observation\"][\"gripper_position\"],\n                    },\n                    \"prompt\": instruction,\n                    \"step_id\": step_id,\n                    \"passes_filter\": passes_filter,\n                }\n\n            dataset = dataset.traj_map(restructure, num_parallel_calls)\n\n            def chunk_actions(traj):\n                \"\"\"Splits episode into action chunks.\"\"\"\n                traj_len = tf.shape(traj[\"actions\"])[0]\n\n                # For 
each step in the trajectory, construct indices for the next n actions\n                action_chunk_indices = tf.broadcast_to(\n                    tf.range(action_chunk_size)[None],\n                    [traj_len, action_chunk_size],\n                ) + tf.broadcast_to(\n                    tf.range(traj_len)[:, None],\n                    [traj_len, action_chunk_size],\n                )\n\n                # Cap to length of the sequence --> final chunks will repeat the last action\n                # This makes sense, since we are using absolute joint + gripper position actions\n                action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1)\n\n                # Gather the actions for each chunk\n                traj[\"actions\"] = tf.gather(traj[\"actions\"], action_chunk_indices)\n                return traj\n\n            dataset = dataset.traj_map(chunk_actions, num_parallel_calls)\n\n            # Flatten: map from trajectory dataset to dataset of individual action chunks\n            dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)\n\n            # Filter data that doesn't pass the filter\n            def filter_from_dict(frame):\n                return frame[\"passes_filter\"]\n\n            dataset = dataset.filter(filter_from_dict)\n\n            # Remove \"passes_filter\" key from output\n            def remove_passes_filter(frame):\n                frame.pop(\"passes_filter\")\n                return frame\n\n            dataset = dataset.map(remove_passes_filter)\n\n            # Decode images: RLDS saves encoded images, only decode now for efficiency\n            def decode_images(traj):\n                traj[\"observation\"][\"image\"] = tf.io.decode_image(\n                    traj[\"observation\"][\"image\"], expand_animations=False, dtype=tf.uint8\n                )\n                traj[\"observation\"][\"wrist_image\"] = tf.io.decode_image(\n                    traj[\"observation\"][\"wrist_image\"], 
expand_animations=False, dtype=tf.uint8\n                )\n                return traj\n\n            return dataset.frame_map(decode_images, num_parallel_calls)\n\n        logging.info(f\"Preparing {len(datasets)} datasets...\")\n        logging.info(\"-\" * 50)\n        for dataset in datasets:\n            logging.info(f\"    {dataset.name}:{dataset.version} with weight {dataset.weight:.2f}\")\n        logging.info(\"-\" * 50)\n        all_datasets = [prepare_single_dataset(dataset) for dataset in datasets]\n        weights = [dataset.weight for dataset in datasets]\n\n        final_dataset = dl.DLataset.sample_from_datasets(all_datasets, weights=weights)\n        final_dataset = final_dataset.shuffle(shuffle_buffer_size)\n        final_dataset = final_dataset.batch(batch_size)\n        # Note =>> Seems to reduce memory usage without affecting speed?\n        final_dataset = final_dataset.with_ram_budget(1)\n\n        self.dataset = final_dataset\n        self.batch_size = batch_size\n        self.shuffle = shuffle\n\n    def __iter__(self):\n        yield from self.dataset.as_numpy_iterator()\n\n    def __len__(self):\n        # This is the approximate number of samples in DROID after filtering.\n        # Easier to hardcode than to iterate through the dataset and compute it.\n        return 20_000_000\n"
  },
  {
    "path": "src/openpi/training/misc/polaris_config.py",
    "content": "\"\"\"PolaRiS baseline policy configs.\"\"\"\n\nfrom typing import TypeAlias\n\nimport openpi.models.model as _model\nimport openpi.models.pi0_config as pi0_config\nimport openpi.models.pi0_fast as pi0_fast\nimport openpi.models.tokenizer as _tokenizer\nimport openpi.policies.droid_policy as droid_policy\nimport openpi.training.droid_rlds_dataset as droid_rlds_dataset\nimport openpi.training.optimizer as _optimizer\nimport openpi.training.weight_loaders as weight_loaders\nimport openpi.transforms as _transforms\n\nModelType: TypeAlias = _model.ModelType\n\n\ndef get_polaris_configs():\n    # Import here to avoid circular imports.\n    from openpi.training.config import AssetsConfig\n    from openpi.training.config import RLDSDroidDataConfig\n    from openpi.training.config import SimpleDataConfig\n    from openpi.training.config import TrainConfig\n\n    return [\n        #\n        # PolaRiS DROID jointpos policies\n        #\n        TrainConfig(\n            name=\"pi05_droid_jointpos_polaris\",\n            model=pi0_config.Pi0Config(action_horizon=15, pi05=True),\n            data=RLDSDroidDataConfig(\n                assets=AssetsConfig(\n                    assets_dir=\"gs://openpi-assets/checkpoints/polaris/pi05_droid_jointpos_polaris/assets\",\n                    asset_id=\"droid\",\n                ),\n                datasets=(\n                    droid_rlds_dataset.RLDSDataset(\n                        name=\"droid\",\n                        version=\"1.0.1\",\n                        weight=0.9,\n                        filter_dict_path=\"gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json\",\n                    ),\n                    droid_rlds_dataset.RLDSDataset(\n                        name=\"polaris_droid_cotrain_dataset\",\n                        version=\"1.0.0\",\n                        weight=0.1,\n                        
filter_dict_path=\"gs://openpi-assets/droid/polaris_droid_cotrain_dataset_sample_ranges_v1_0_0.json\",\n                    ),\n                ),\n                rlds_data_dir=\"<path_to_droid_rlds_dataset>\",\n                action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,\n            ),\n            weight_loader=weight_loaders.CheckpointWeightLoader(\n                \"gs://openpi-assets/checkpoints/polaris/pi05_droid_jointpos_polaris/params\"\n            ),\n            lr_schedule=_optimizer.CosineDecaySchedule(\n                warmup_steps=1_000,\n                peak_lr=5e-5,\n                decay_steps=1_000_000,\n                decay_lr=5e-5,\n            ),\n            num_train_steps=1_000,\n            batch_size=128,\n            log_interval=100,\n            save_interval=1000,\n            keep_period=1000,\n            num_workers=0,  # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally\n        ),\n        TrainConfig(\n            name=\"pi0_fast_droid_jointpos_polaris\",\n            model=pi0_fast.Pi0FASTConfig(\n                action_dim=8,\n                action_horizon=10,\n                max_token_len=180,\n            ),\n            data=RLDSDroidDataConfig(\n                assets=AssetsConfig(\n                    assets_dir=\"gs://openpi-assets/checkpoints/polaris/pi0_fast_droid_jointpos_polaris/assets\",\n                    asset_id=\"droid\",\n                ),\n                datasets=(\n                    droid_rlds_dataset.RLDSDataset(\n                        name=\"droid\",\n                        version=\"1.0.1\",\n                        weight=0.9,\n                        filter_dict_path=\"gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json\",\n                    ),\n                    droid_rlds_dataset.RLDSDataset(\n                        name=\"polaris_droid_cotrain_dataset\",\n                        version=\"1.0.0\",\n                        
weight=0.1,\n                        filter_dict_path=\"gs://openpi-assets/droid/polaris_droid_cotrain_dataset_sample_ranges_v1_0_0.json\",\n                    ),\n                ),\n                rlds_data_dir=\"<path_to_droid_rlds_dataset>\",\n                action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,\n            ),\n            weight_loader=weight_loaders.CheckpointWeightLoader(\n                \"gs://openpi-assets/checkpoints/polaris/pi0_fast_droid_jointpos_polaris/params\"\n            ),\n            lr_schedule=_optimizer.CosineDecaySchedule(\n                warmup_steps=1_000,\n                peak_lr=5e-5,\n                decay_steps=1_000_000,\n                decay_lr=5e-5,\n            ),\n            num_train_steps=1_000,\n            batch_size=128,\n            log_interval=100,\n            save_interval=1000,\n            keep_period=1000,\n            num_workers=0,  # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally\n        ),\n        TrainConfig(\n            name=\"pi0_droid_jointpos_polaris\",\n            model=pi0_config.Pi0Config(\n                # action_dim=8, # leave as 32 default...\n                action_horizon=10,\n                max_token_len=100,\n            ),\n            data=RLDSDroidDataConfig(\n                assets=AssetsConfig(\n                    assets_dir=\"gs://openpi-assets/checkpoints/polaris/pi0_droid_jointpos_polaris/assets\",\n                    asset_id=\"droid\",\n                ),\n                datasets=(\n                    droid_rlds_dataset.RLDSDataset(\n                        name=\"droid\",\n                        version=\"1.0.1\",\n                        weight=0.9,\n                        filter_dict_path=\"gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json\",\n                    ),\n                    droid_rlds_dataset.RLDSDataset(\n                        name=\"polaris_droid_cotrain_dataset\",\n            
            version=\"1.0.0\",\n                        weight=0.1,\n                        filter_dict_path=\"gs://openpi-assets/droid/polaris_droid_cotrain_dataset_sample_ranges_v1_0_0.json\",\n                    ),\n                ),\n                rlds_data_dir=\"<path_to_droid_rlds_dataset>\",\n                action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,\n            ),\n            weight_loader=weight_loaders.CheckpointWeightLoader(\n                \"gs://openpi-assets/checkpoints/polaris/pi0_droid_jointpos_polaris/params\"\n            ),\n            lr_schedule=_optimizer.CosineDecaySchedule(\n                warmup_steps=1_000,\n                peak_lr=5e-5,\n                decay_steps=1_000_000,\n                decay_lr=5e-5,\n            ),\n            num_train_steps=1_000,\n            batch_size=128,\n            log_interval=100,\n            save_interval=1000,\n            keep_period=1000,\n            num_workers=0,  # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally\n        ),\n        TrainConfig(\n            name=\"pi0_droid_jointpos_100k_polaris\",\n            model=pi0_config.Pi0Config(\n                # action_dim=8, # leave as 32 default...\n                action_horizon=10,\n                max_token_len=100,\n            ),\n            data=RLDSDroidDataConfig(\n                assets=AssetsConfig(\n                    assets_dir=\"gs://openpi-assets/checkpoints/polaris/pi0_droid_jointpos_100k_polaris/assets\",\n                    asset_id=\"droid\",\n                ),\n                datasets=(\n                    droid_rlds_dataset.RLDSDataset(\n                        name=\"droid\",\n                        version=\"1.0.1\",\n                        weight=0.9,\n                        filter_dict_path=\"gs://openpi-assets/droid/droid_sample_ranges_v1_0_1.json\",\n                    ),\n                    droid_rlds_dataset.RLDSDataset(\n                
        name=\"polaris_droid_cotrain_dataset\",\n                        version=\"1.0.0\",\n                        weight=0.1,\n                        filter_dict_path=\"gs://openpi-assets/droid/polaris_droid_cotrain_dataset_sample_ranges_v1_0_0.json\",\n                    ),\n                ),\n                rlds_data_dir=\"<path_to_droid_rlds_dataset>\",\n                action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,\n            ),\n            weight_loader=weight_loaders.CheckpointWeightLoader(\n                \"gs://openpi-assets/checkpoints/polaris/pi0_droid_jointpos_100k_polaris/params\"\n            ),\n            lr_schedule=_optimizer.CosineDecaySchedule(\n                warmup_steps=1_000,\n                peak_lr=5e-5,\n                decay_steps=1_000_000,\n                decay_lr=5e-5,\n            ),\n            num_train_steps=1_000,\n            batch_size=128,\n            log_interval=100,\n            save_interval=1000,\n            keep_period=1000,\n            num_workers=0,  # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally\n        ),\n        # openpi doesn't support finetuning of binning policies, so this is an inference-only config\n        TrainConfig(\n            name=\"paligemma_binning_droid_jointpos\",\n            model=pi0_fast.Pi0FASTConfig(\n                action_dim=8,\n                action_horizon=15,\n                max_token_len=600,\n                fast_model_tokenizer=_tokenizer.BinningTokenizer,\n            ),\n            data=SimpleDataConfig(\n                assets=AssetsConfig(asset_id=\"droid\"),\n                data_transforms=lambda model: _transforms.Group(\n                    inputs=[droid_policy.DroidInputs(model_type=ModelType.PI0_FAST)],\n                    outputs=[\n                        _transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)),\n                        droid_policy.DroidOutputs(),\n                    
],\n                ),\n            ),\n        ),\n    ]\n"
  },
  {
    "path": "src/openpi/training/misc/roboarena_config.py",
    "content": "\"\"\"RoboArena baseline policy configs.\"\"\"\n\nfrom typing import TypeAlias\n\nimport openpi.models.model as _model\nimport openpi.models.pi0_config as pi0_config\nimport openpi.models.pi0_fast as pi0_fast\nimport openpi.models.tokenizer as _tokenizer\nimport openpi.policies.droid_policy as droid_policy\nimport openpi.transforms as _transforms\n\nModelType: TypeAlias = _model.ModelType\n\n\ndef get_roboarena_configs():\n    # Import here to avoid circular imports.\n    from openpi.training.config import AssetsConfig\n    from openpi.training.config import DataConfig\n    from openpi.training.config import SimpleDataConfig\n    from openpi.training.config import TrainConfig\n\n    return [\n        #\n        # RoboArena DROID baseline inference configs.\n        #\n        TrainConfig(\n            # Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.\n            name=\"paligemma_binning_droid\",\n            model=pi0_fast.Pi0FASTConfig(\n                action_dim=8,\n                action_horizon=15,\n                max_token_len=400,\n                fast_model_tokenizer=_tokenizer.BinningTokenizer,\n            ),\n            data=SimpleDataConfig(\n                assets=AssetsConfig(asset_id=\"droid\"),\n                data_transforms=lambda model: _transforms.Group(\n                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],\n                    outputs=[droid_policy.DroidOutputs()],\n                ),\n                base_config=DataConfig(\n                    prompt_from_task=True,\n                ),\n            ),\n        ),\n        TrainConfig(\n            # Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).\n            name=\"paligemma_fast_droid\",\n            model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),\n            data=SimpleDataConfig(\n                assets=AssetsConfig(asset_id=\"droid\"),\n  
              data_transforms=lambda model: _transforms.Group(\n                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],\n                    outputs=[droid_policy.DroidOutputs()],\n                ),\n                base_config=DataConfig(\n                    prompt_from_task=True,\n                ),\n            ),\n        ),\n        TrainConfig(\n            # Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).\n            name=\"paligemma_fast_specialist_droid\",\n            model=pi0_fast.Pi0FASTConfig(\n                action_dim=8,\n                action_horizon=15,\n                fast_model_tokenizer=_tokenizer.FASTTokenizer,\n                fast_model_tokenizer_kwargs={\"fast_tokenizer_path\": \"KarlP/fast_droid_specialist\"},\n            ),\n            data=SimpleDataConfig(\n                assets=AssetsConfig(asset_id=\"droid\"),\n                data_transforms=lambda model: _transforms.Group(\n                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],\n                    outputs=[droid_policy.DroidOutputs()],\n                ),\n                base_config=DataConfig(\n                    prompt_from_task=True,\n                ),\n            ),\n        ),\n        TrainConfig(\n            # Trained from PaliGemma, using FSQ tokenizer.\n            name=\"paligemma_vq_droid\",\n            model=pi0_fast.Pi0FASTConfig(\n                action_dim=8,\n                action_horizon=15,\n                fast_model_tokenizer=_tokenizer.FSQTokenizer,\n                fast_model_tokenizer_kwargs={\"fsq_tokenizer_path\": \"gs://openpi-assets/tokenizers/droid_fsq_tokenizer\"},\n            ),\n            data=SimpleDataConfig(\n                assets=AssetsConfig(asset_id=\"droid\"),\n                data_transforms=lambda model: _transforms.Group(\n                    
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],\n                    outputs=[droid_policy.DroidOutputs()],\n                ),\n                base_config=DataConfig(\n                    prompt_from_task=True,\n                ),\n            ),\n        ),\n        TrainConfig(\n            # pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.\n            name=\"paligemma_diffusion_droid\",\n            model=pi0_config.Pi0Config(action_horizon=10, action_dim=8),\n            data=SimpleDataConfig(\n                assets=AssetsConfig(asset_id=\"droid\"),\n                data_transforms=lambda model: _transforms.Group(\n                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],\n                    outputs=[droid_policy.DroidOutputs()],\n                ),\n                base_config=DataConfig(\n                    prompt_from_task=True,\n                ),\n            ),\n        ),\n    ]\n"
  },
  {
    "path": "src/openpi/training/optimizer.py",
    "content": "import dataclasses\nfrom typing import Protocol, runtime_checkable\n\nimport jax.numpy as jnp\nimport optax\n\nimport openpi.shared.array_typing as at\n\n\n@runtime_checkable\nclass LRScheduleConfig(Protocol):\n    def create(self) -> optax.Schedule: ...\n\n\n@dataclasses.dataclass(frozen=True)\nclass CosineDecaySchedule(LRScheduleConfig):\n    \"\"\"Cosine decay schedule with warmup.\"\"\"\n\n    warmup_steps: int = 1_000\n    peak_lr: float = 2.5e-5\n    decay_steps: int = 30_000\n    decay_lr: float = 2.5e-6\n\n    def create(self) -> optax.Schedule:\n        return optax.warmup_cosine_decay_schedule(\n            init_value=self.peak_lr / (self.warmup_steps + 1),\n            peak_value=self.peak_lr,\n            warmup_steps=self.warmup_steps,\n            decay_steps=self.decay_steps,\n            end_value=self.decay_lr,\n        )\n\n\n@dataclasses.dataclass(frozen=True)\nclass RsqrtDecaySchedule(LRScheduleConfig):\n    \"\"\"Inverse square root decay schedule with warmup.\"\"\"\n\n    warmup_steps: int = 1_000\n    peak_lr: float = 5e-5\n    timescale: float = 10_000\n\n    def create(self) -> optax.Schedule:\n        return optax.join_schedules(\n            [\n                optax.linear_schedule(\n                    init_value=self.peak_lr / (self.warmup_steps + 1),\n                    end_value=self.peak_lr,\n                    transition_steps=self.warmup_steps,\n                ),\n                lambda step: self.peak_lr / jnp.sqrt((self.timescale + step) / self.timescale),\n            ],\n            [self.warmup_steps],\n        )\n\n\n@runtime_checkable\nclass OptimizerConfig(Protocol):\n    def create(\n        self,\n        lr: optax.ScalarOrSchedule,\n        weight_decay_mask: at.PyTree | None = None,\n    ) -> optax.GradientTransformation: ...\n\n\n@dataclasses.dataclass(frozen=True)\nclass AdamW(OptimizerConfig):\n    \"\"\"AdamW optimizer.\"\"\"\n\n    b1: float = 0.9\n    b2: float = 0.95\n    eps: float = 1e-8\n    
# Changing this to 0 can cause out-of-memory errors for some reason, so we set it to a negligible value.\n    weight_decay: float = 1e-10\n    clip_gradient_norm: float = 1.0\n\n    def create(\n        self,\n        lr: optax.ScalarOrSchedule,\n        weight_decay_mask: at.PyTree | None = None,\n    ) -> optax.GradientTransformation:\n        tx = optax.adamw(\n            lr, b1=self.b1, b2=self.b2, eps=self.eps, weight_decay=self.weight_decay, mask=weight_decay_mask\n        )\n\n        return optax.chain(optax.clip_by_global_norm(self.clip_gradient_norm), tx)\n\n\n@dataclasses.dataclass(frozen=True)\nclass SGD(OptimizerConfig):\n    \"\"\"SGD optimizer.\"\"\"\n\n    lr: float = 5e-5\n    momentum: float = 0.9\n    nesterov: bool = False\n\n    def create(\n        self,\n        lr: optax.ScalarOrSchedule,\n        weight_decay_mask: at.PyTree | None = None,\n    ) -> optax.GradientTransformation:\n        assert weight_decay_mask is None, \"Weight decay is not supported for SGD\"\n        return optax.sgd(lr, momentum=self.momentum, nesterov=self.nesterov)\n\n\ndef create_optimizer(\n    optimizer: OptimizerConfig, lr_schedule: LRScheduleConfig, weight_decay_mask: at.PyTree | None = None\n) -> optax.GradientTransformation:\n    lr = lr_schedule.create()\n    return optimizer.create(lr, weight_decay_mask=weight_decay_mask)\n"
  },
  {
    "path": "src/openpi/training/sharding.py",
    "content": "import contextlib\nimport logging\n\nimport jax\nimport numpy as np\n\nBATCH_AXIS = \"batch\"\nFSDP_AXIS = \"fsdp\"\n# In FSDP, we shard the data across both the batch and FSDP axes.\nDATA_AXIS = (BATCH_AXIS, FSDP_AXIS)\n\n\nclass _MeshState:\n    active_mesh: jax.sharding.Mesh | None = None\n\n\ndef make_mesh(num_fsdp_devices: int) -> jax.sharding.Mesh:\n    if jax.device_count() % num_fsdp_devices != 0:\n        raise ValueError(\n            f\"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {num_fsdp_devices}.\"\n        )\n    mesh_shape = (jax.device_count() // num_fsdp_devices, num_fsdp_devices)\n    return jax.make_mesh(mesh_shape, (BATCH_AXIS, FSDP_AXIS))\n\n\n@contextlib.contextmanager\ndef set_mesh(mesh: jax.sharding.Mesh):\n    \"\"\"Plumbing the mesh deep into the module tree is extremely cumbersome; until the JAX team lands a better API, a\n    custom context manager like this one is the recommended way to maintain a reference to a global mesh. This is only used\n    in `activation_sharding_constraint` below.\"\"\"\n    if _MeshState.active_mesh is not None:\n        raise ValueError(\"Cannot nest set_mesh context managers.\")\n    _MeshState.active_mesh = mesh\n    try:\n        yield\n    finally:\n        _MeshState.active_mesh = None\n\n\ndef activation_sharding_constraint(pytree):\n    if _MeshState.active_mesh is None:\n        return pytree\n    return jax.lax.with_sharding_constraint(\n        pytree, jax.sharding.NamedSharding(_MeshState.active_mesh, jax.sharding.PartitionSpec(DATA_AXIS))\n    )\n\n\ndef fsdp_sharding(\n    pytree,\n    mesh: jax.sharding.Mesh,\n    *,\n    min_size_mbytes: int = 4,  # 4 MiB\n    log: bool = False,\n):\n    \"\"\"Apply FSDP sharding to a pytree of arrays based on the mesh shape.\n\n    Args:\n        pytree: A pytree to be apply sharding specified by the mesh, note that only array types (eg. 
contains .shape attr)\n          will be considered for sharding.\n        mesh: The mesh being used for applying sharding on to pytree.\n        min_size_mbytes: The minimum size of the array in MiB to be considered for sharding, any array smaller than this\n          will be replicated.\n        log: If true, will log the sharding decisions for arrays that are being considered for sharding.\n\n    Returns:\n        The sharded pytree.\n    \"\"\"\n    min_size_bytes = min_size_mbytes * 2**20\n\n    def _shard_arr(kp, array: jax.ShapeDtypeStruct):\n        # if fsdp is not actually going to be used, replicate everything to avoid extraneous logging\n        if mesh.shape[FSDP_AXIS] == 1:\n            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())\n        # replicate scalar and vector arrays\n        if not hasattr(array, \"shape\"):\n            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())\n        if len(array.shape) < 2:\n            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())\n        # replicate small arrays\n        if (arr_size := np.prod(array.shape) * np.dtype(array.dtype).itemsize) < min_size_bytes:\n            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())\n\n        # shard matrices and larger tensors along the largest axis that is divisible by the fsdp dimension\n        axes = np.argsort(array.shape)[::-1]\n        spec = [None] * len(axes)\n        for i in axes:\n            if array.shape[i] % mesh.shape[FSDP_AXIS] == 0:\n                if log:\n                    logging.info(\n                        f\"Sharding {jax.tree_util.keystr(kp)} of shape {array.shape} ({arr_size / 2**20:.2f} MiB) along axis {i}\"\n                    )\n                spec[i] = FSDP_AXIS\n                return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec))\n\n        # replicate if no valid sharding was found\n        if log:\n            
logging.warning(\n                f\"Could not find a valid sharding for {jax.tree_util.keystr(kp)} of shape {array.shape} with mesh of shape {mesh.shape}\"\n            )\n        return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())\n\n    return jax.tree_util.tree_map_with_path(_shard_arr, pytree)\n"
  },
  {
    "path": "src/openpi/training/utils.py",
    "content": "from collections.abc import Callable\nfrom typing import Any\n\nfrom flax import nnx\nfrom flax import struct\nimport jax\nimport optax\n\nfrom openpi.models import model as _model\nfrom openpi.shared import array_typing as at\n\n\n@at.typecheck\n@struct.dataclass\nclass TrainState:\n    step: at.Int[at.ArrayLike, \"\"]\n    params: nnx.State\n    model_def: nnx.GraphDef[_model.BaseModel]\n    opt_state: optax.OptState\n    tx: optax.GradientTransformation = struct.field(pytree_node=False)\n\n    ema_decay: float | None = struct.field(pytree_node=False)\n    ema_params: nnx.State | None = None\n\n\n@at.typecheck\ndef tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str:\n    \"\"\"Converts a PyTree into a human-readable string for logging. Optionally, `interp_func` can be provided to convert\n    the leaf values to more meaningful strings.\n    \"\"\"\n    tree, _ = jax.tree_util.tree_flatten_with_path(tree)\n    return \"\\n\".join(f\"{jax.tree_util.keystr(path)}: {interp_func(value)}\" for path, value in tree)\n\n\n@at.typecheck\ndef array_tree_to_info(tree: at.PyTree) -> str:\n    \"\"\"Converts a PyTree of arrays into a human-readable string for logging.\"\"\"\n    return tree_to_info(tree, lambda x: f\"{x.shape}@{x.dtype}\")\n"
  },
  {
    "path": "src/openpi/training/weight_loaders.py",
    "content": "import dataclasses\nimport logging\nimport re\nfrom typing import Protocol, runtime_checkable\n\nimport flax.traverse_util\nimport numpy as np\n\nimport openpi.models.model as _model\nimport openpi.shared.array_typing as at\nimport openpi.shared.download as download\n\nlogger = logging.getLogger(__name__)\n\n\n@runtime_checkable\nclass WeightLoader(Protocol):\n    def load(self, params: at.Params) -> at.Params:\n        \"\"\"Loads the model weights.\n\n        Args:\n            params: Parameters of the model. This is a nested structure of array-like objects that\n                represent the model's parameters.\n\n        Returns:\n            Loaded parameters. The structure must be identical to `params`. If returning a subset of\n            the parameters the loader must merge the loaded parameters with `params`.\n        \"\"\"\n\n\n@dataclasses.dataclass(frozen=True)\nclass NoOpWeightLoader(WeightLoader):\n    def load(self, params: at.Params) -> at.Params:\n        return params\n\n\n@dataclasses.dataclass(frozen=True)\nclass CheckpointWeightLoader(WeightLoader):\n    \"\"\"Loads an entire set of weights from a checkpoint.\n\n    Compatible with:\n      trained checkpoints:\n        example: \"./checkpoints/<config>/<exp>/<step>/params\"\n      released checkpoints:\n        example: \"gs://openpi-assets/checkpoints/<model>/params\"\n    \"\"\"\n\n    params_path: str\n\n    def load(self, params: at.Params) -> at.Params:\n        # We are loading np.ndarray and relying on the training code to properly convert and shard the params.\n        loaded_params = _model.restore_params(download.maybe_download(self.params_path), restore_type=np.ndarray)\n        # Add all missing LoRA weights.\n        return _merge_params(loaded_params, params, missing_regex=\".*lora.*\")\n\n\n@dataclasses.dataclass(frozen=True)\nclass PaliGemmaWeightLoader(WeightLoader):\n    \"\"\"Loads weights from the official PaliGemma checkpoint.\n\n    This will overwrite 
existing weights with similar names while keeping all extra weights intact.\n    This allows us to support the action expert which is used by the Pi0 model.\n    \"\"\"\n\n    def load(self, params: at.Params) -> at.Params:\n        path = download.maybe_download(\n            \"gs://vertex-model-garden-paligemma-us/paligemma/pt_224.npz\", gs={\"token\": \"anon\"}\n        )\n        with path.open(\"rb\") as f:\n            flat_params = dict(np.load(f, allow_pickle=False))\n        loaded_params = {\"PaliGemma\": flax.traverse_util.unflatten_dict(flat_params, sep=\"/\")[\"params\"]}\n        # Add all missing weights.\n        return _merge_params(loaded_params, params, missing_regex=\".*\")\n\n\ndef _merge_params(loaded_params: at.Params, params: at.Params, *, missing_regex: str) -> at.Params:\n    \"\"\"Merges the loaded parameters with the reference parameters.\n\n    Args:\n        loaded_params: The parameters to merge.\n        params: The reference parameters.\n        missing_regex: A regex pattern for all missing keys that should be merged from the reference parameters.\n\n    Returns:\n        A new dictionary with the merged parameters.\n    \"\"\"\n    flat_ref = flax.traverse_util.flatten_dict(params, sep=\"/\")\n    flat_loaded = flax.traverse_util.flatten_dict(loaded_params, sep=\"/\")\n\n    # First, take all weights that are a subset of the reference weights.\n    result = {}\n    for k, v in flat_loaded.items():\n        if k in flat_ref:\n            result[k] = v.astype(flat_ref[k].dtype) if v.dtype != flat_ref[k].dtype else v\n\n    flat_loaded.clear()\n\n    # Then, merge any missing weights as defined by the missing regex.\n    pattern = re.compile(missing_regex)\n    for k in {k for k in flat_ref if pattern.fullmatch(k)}:\n        if k not in result:\n            result[k] = flat_ref[k]\n\n    return flax.traverse_util.unflatten_dict(result, sep=\"/\")\n"
  },
  {
    "path": "src/openpi/transforms.py",
    "content": "from collections.abc import Callable, Mapping, Sequence\nimport dataclasses\nimport re\nfrom typing import Protocol, TypeAlias, TypeVar, runtime_checkable\n\nimport flax.traverse_util as traverse_util\nimport jax\nimport numpy as np\nfrom openpi_client import image_tools\n\nfrom openpi.models import tokenizer as _tokenizer\nfrom openpi.shared import array_typing as at\nfrom openpi.shared import normalize as _normalize\n\nDataDict: TypeAlias = at.PyTree\nNormStats: TypeAlias = _normalize.NormStats\n\n\nT = TypeVar(\"T\")\nS = TypeVar(\"S\")\n\n\n@runtime_checkable\nclass DataTransformFn(Protocol):\n    def __call__(self, data: DataDict) -> DataDict:\n        \"\"\"Apply transformation to the data.\n\n        Args:\n            data: The data to apply the transform to. This is a possibly nested dictionary that contains\n                unbatched data elements. Each leaf is expected to be a numpy array. Using JAX arrays is allowed\n                but not recommended since it may result in extra GPU memory usage inside data loader worker\n                processes.\n\n        Returns:\n            The transformed data. 
Could be the input `data` that was modified in place, or a new data structure.\n        \"\"\"\n\n\n@dataclasses.dataclass(frozen=True)\nclass Group:\n    \"\"\"A group of transforms.\"\"\"\n\n    # Transforms that are applied to the model input data.\n    inputs: Sequence[DataTransformFn] = ()\n\n    # Transforms that are applied to the model output data.\n    outputs: Sequence[DataTransformFn] = ()\n\n    def push(self, *, inputs: Sequence[DataTransformFn] = (), outputs: Sequence[DataTransformFn] = ()) -> \"Group\":\n        \"\"\"Append transforms to the group and return a new group.\n\n        Args:\n            inputs: Appended to the *end* of the current input transforms.\n            outputs: Appended to the *beginning* of the current output transforms.\n\n        Returns:\n            A new group with the appended transforms.\n        \"\"\"\n        return Group(inputs=(*self.inputs, *inputs), outputs=(*outputs, *self.outputs))\n\n\n@dataclasses.dataclass(frozen=True)\nclass CompositeTransform(DataTransformFn):\n    \"\"\"A composite transform that applies a sequence of transforms in order.\"\"\"\n\n    transforms: Sequence[DataTransformFn]\n\n    def __call__(self, data: DataDict) -> DataDict:\n        for transform in self.transforms:\n            data = transform(data)\n        return data\n\n\ndef compose(transforms: Sequence[DataTransformFn]) -> DataTransformFn:\n    \"\"\"Compose a sequence of transforms into a single transform.\"\"\"\n    return CompositeTransform(transforms)\n\n\n@dataclasses.dataclass(frozen=True)\nclass RepackTransform(DataTransformFn):\n    \"\"\"Repacks an input dictionary into a new dictionary.\n\n    Repacking is defined using a dictionary where the keys are the new keys and the values\n    are the flattened paths to the old keys. 
We use '/' as the separator during flattening.\n\n    Example:\n    {\n        \"images\": {\n            \"cam_high\": \"observation.images.top\",\n            \"cam_low\": \"observation.images.bottom\",\n        },\n        \"state\": \"observation.state\",\n        \"actions\": \"action\",\n    }\n    \"\"\"\n\n    structure: at.PyTree[str]\n\n    def __call__(self, data: DataDict) -> DataDict:\n        flat_item = flatten_dict(data)\n        return jax.tree.map(lambda k: flat_item[k], self.structure)\n\n\n@dataclasses.dataclass(frozen=True)\nclass InjectDefaultPrompt(DataTransformFn):\n    prompt: str | None\n\n    def __call__(self, data: DataDict) -> DataDict:\n        if self.prompt is not None and \"prompt\" not in data:\n            data[\"prompt\"] = np.asarray(self.prompt)\n        return data\n\n\n@dataclasses.dataclass(frozen=True)\nclass Normalize(DataTransformFn):\n    norm_stats: at.PyTree[NormStats] | None\n    # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used.\n    use_quantiles: bool = False\n    # If true, will raise an error if any of the keys in the norm stats are not present in the data.\n    strict: bool = False\n\n    def __post_init__(self):\n        if self.norm_stats is not None and self.use_quantiles:\n            _assert_quantile_stats(self.norm_stats)\n\n    def __call__(self, data: DataDict) -> DataDict:\n        if self.norm_stats is None:\n            return data\n\n        return apply_tree(\n            data,\n            self.norm_stats,\n            self._normalize_quantile if self.use_quantiles else self._normalize,\n            strict=self.strict,\n        )\n\n    def _normalize(self, x, stats: NormStats):\n        mean, std = stats.mean[..., : x.shape[-1]], stats.std[..., : x.shape[-1]]\n        return (x - mean) / (std + 1e-6)\n\n    def _normalize_quantile(self, x, stats: NormStats):\n        assert stats.q01 is not None\n        assert stats.q99 is not None\n        q01, q99 
= stats.q01[..., : x.shape[-1]], stats.q99[..., : x.shape[-1]]\n        return (x - q01) / (q99 - q01 + 1e-6) * 2.0 - 1.0\n\n\n@dataclasses.dataclass(frozen=True)\nclass Unnormalize(DataTransformFn):\n    norm_stats: at.PyTree[NormStats] | None\n    # If true, will use quantile normalization. Otherwise, normal z-score normalization will be used.\n    use_quantiles: bool = False\n\n    def __post_init__(self):\n        if self.norm_stats is not None and self.use_quantiles:\n            _assert_quantile_stats(self.norm_stats)\n\n    def __call__(self, data: DataDict) -> DataDict:\n        if self.norm_stats is None:\n            return data\n\n        # Make sure that all the keys in the norm stats are present in the data.\n        return apply_tree(\n            data,\n            self.norm_stats,\n            self._unnormalize_quantile if self.use_quantiles else self._unnormalize,\n            strict=True,\n        )\n\n    def _unnormalize(self, x, stats: NormStats):\n        mean = pad_to_dim(stats.mean, x.shape[-1], axis=-1, value=0.0)\n        std = pad_to_dim(stats.std, x.shape[-1], axis=-1, value=1.0)\n        return x * (std + 1e-6) + mean\n\n    def _unnormalize_quantile(self, x, stats: NormStats):\n        assert stats.q01 is not None\n        assert stats.q99 is not None\n        q01, q99 = stats.q01, stats.q99\n        if (dim := q01.shape[-1]) < x.shape[-1]:\n            return np.concatenate([(x[..., :dim] + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01, x[..., dim:]], axis=-1)\n        return (x + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01\n\n\n@dataclasses.dataclass(frozen=True)\nclass ResizeImages(DataTransformFn):\n    height: int\n    width: int\n\n    def __call__(self, data: DataDict) -> DataDict:\n        data[\"image\"] = {k: image_tools.resize_with_pad(v, self.height, self.width) for k, v in data[\"image\"].items()}\n        return data\n\n\n@dataclasses.dataclass(frozen=True)\nclass SubsampleActions(DataTransformFn):\n    stride: int\n\n    def 
__call__(self, data: DataDict) -> DataDict:\n        data[\"actions\"] = data[\"actions\"][:: self.stride]\n        return data\n\n\n@dataclasses.dataclass(frozen=True)\nclass DeltaActions(DataTransformFn):\n    \"\"\"Repacks absolute actions into delta action space.\"\"\"\n\n    # Boolean mask for the action dimensions to be repacked into delta action space. Length\n    # can be smaller than the actual number of dimensions. If None, this transform is a no-op.\n    # See `make_bool_mask` for more details.\n    mask: Sequence[bool] | None\n\n    def __call__(self, data: DataDict) -> DataDict:\n        if \"actions\" not in data or self.mask is None:\n            return data\n\n        state, actions = data[\"state\"], data[\"actions\"]\n        mask = np.asarray(self.mask)\n        dims = mask.shape[-1]\n        actions[..., :dims] -= np.expand_dims(np.where(mask, state[..., :dims], 0), axis=-2)\n        data[\"actions\"] = actions\n\n        return data\n\n\n@dataclasses.dataclass(frozen=True)\nclass AbsoluteActions(DataTransformFn):\n    \"\"\"Repacks delta actions into absolute action space.\"\"\"\n\n    # Boolean mask for the action dimensions to be repacked into absolute action space. Length\n    # can be smaller than the actual number of dimensions. 
If None, this transform is a no-op.\n    # See `make_bool_mask` for more details.\n    mask: Sequence[bool] | None\n\n    def __call__(self, data: DataDict) -> DataDict:\n        if \"actions\" not in data or self.mask is None:\n            return data\n\n        state, actions = data[\"state\"], data[\"actions\"]\n        mask = np.asarray(self.mask)\n        dims = mask.shape[-1]\n        actions[..., :dims] += np.expand_dims(np.where(mask, state[..., :dims], 0), axis=-2)\n        data[\"actions\"] = actions\n\n        return data\n\n\n@dataclasses.dataclass(frozen=True)\nclass TokenizePrompt(DataTransformFn):\n    tokenizer: _tokenizer.PaligemmaTokenizer\n    discrete_state_input: bool = False\n\n    def __call__(self, data: DataDict) -> DataDict:\n        if (prompt := data.pop(\"prompt\", None)) is None:\n            raise ValueError(\"Prompt is required\")\n\n        if self.discrete_state_input:\n            if (state := data.get(\"state\", None)) is None:\n                raise ValueError(\"State is required.\")\n        else:\n            state = None\n\n        if not isinstance(prompt, str):\n            prompt = prompt.item()\n\n        tokens, token_masks = self.tokenizer.tokenize(prompt, state)\n        return {**data, \"tokenized_prompt\": tokens, \"tokenized_prompt_mask\": token_masks}\n\n\n@dataclasses.dataclass(frozen=True)\nclass TokenizeFASTInputs(DataTransformFn):\n    tokenizer: _tokenizer.FASTTokenizer\n\n    def __call__(self, data: DataDict) -> DataDict:\n        if (prompt := data.pop(\"prompt\", None)) is None:\n            raise ValueError(\"Prompt is required\")\n\n        if not isinstance(prompt, str):\n            prompt = prompt.item()\n\n        state, actions = data[\"state\"], data.get(\"actions\")\n        tokens, token_mask, ar_mask, loss_mask = self.tokenizer.tokenize(prompt, state, actions)\n        return {\n            **data,\n            \"tokenized_prompt\": tokens,\n            \"tokenized_prompt_mask\": token_mask,\n   
         \"token_ar_mask\": ar_mask,\n            \"token_loss_mask\": loss_mask,\n        }\n\n\n@dataclasses.dataclass(frozen=True)\nclass ExtractFASTActions(DataTransformFn):\n    tokenizer: _tokenizer.FASTTokenizer\n    action_horizon: int\n    action_dim: int\n\n    def __call__(self, data: DataDict) -> DataDict:\n        if \"actions\" not in data:\n            return data\n        # Model outputs are saved in \"actions\", but for FAST models they represent tokens.\n        tokens = data.pop(\"actions\")\n        actions = self.tokenizer.extract_actions(tokens.astype(np.int32), self.action_horizon, self.action_dim)\n        return {\n            **data,\n            \"actions\": actions,\n        }\n\n\n@dataclasses.dataclass(frozen=True)\nclass PromptFromLeRobotTask(DataTransformFn):\n    \"\"\"Extracts a prompt from the current LeRobot dataset task.\"\"\"\n\n    # Contains the LeRobot dataset tasks (dataset.meta.tasks).\n    tasks: dict[int, str]\n\n    def __call__(self, data: DataDict) -> DataDict:\n        if \"task_index\" not in data:\n            raise ValueError('Cannot extract prompt without \"task_index\"')\n\n        task_index = int(data[\"task_index\"])\n        if (prompt := self.tasks.get(task_index)) is None:\n            raise ValueError(f\"{task_index=} not found in task mapping: {self.tasks}\")\n\n        return {**data, \"prompt\": prompt}\n\n\n@dataclasses.dataclass(frozen=True)\nclass PadStatesAndActions(DataTransformFn):\n    \"\"\"Zero-pads states and actions to the model action dimension.\"\"\"\n\n    model_action_dim: int\n\n    def __call__(self, data: DataDict) -> DataDict:\n        data[\"state\"] = pad_to_dim(data[\"state\"], self.model_action_dim, axis=-1)\n        if \"actions\" in data:\n            data[\"actions\"] = pad_to_dim(data[\"actions\"], self.model_action_dim, axis=-1)\n        return data\n\n\ndef flatten_dict(tree: at.PyTree) -> dict:\n    \"\"\"Flatten a nested dictionary. 
Uses '/' as the separator.\"\"\"\n    return traverse_util.flatten_dict(tree, sep=\"/\")\n\n\ndef unflatten_dict(tree: dict) -> at.PyTree:\n    \"\"\"Unflatten a flattened dictionary. Assumes that '/' was used as a separator.\"\"\"\n    return traverse_util.unflatten_dict(tree, sep=\"/\")\n\n\ndef transform_dict(patterns: Mapping[str, str | None], tree: at.PyTree) -> at.PyTree:\n    \"\"\"Transform the structure of a nested dictionary using a set of patterns.\n\n    The transformation is defined using the `patterns` dictionary. The keys are the\n    input keys that should be matched and the values are the new names inside the output\n    dictionary. If the value is None, the input key is removed.\n\n    Both keys and values should represent flattened paths using '/' as the separator.\n    Keys can be regular expressions and values can include backreferences to the\n    matched groups (see `re.sub` for more details). Note that the regular expression\n    must match the entire key.\n\n    The order inside the `patterns` dictionary is important. 
Only the first pattern that\n    matches the input key will be used.\n\n    See unit tests for more examples.\n\n    Args:\n        patterns: A mapping from old keys to new keys.\n        tree: The nested dictionary to transform.\n\n    Returns:\n        The transformed nested dictionary.\n    \"\"\"\n    data = flatten_dict(tree)\n\n    # Compile the patterns.\n    compiled = {re.compile(k): v for k, v in patterns.items()}\n\n    output = {}\n    for k in data:\n        for pattern, repl in compiled.items():\n            if pattern.fullmatch(k):\n                new_k = pattern.sub(repl, k, count=1) if repl is not None else None\n                break\n        else:\n            # Use the original key if no match is found.\n            new_k = k\n\n        if new_k is not None:\n            if new_k in output:\n                raise ValueError(f\"Key '{new_k}' already exists in output\")\n            output[new_k] = data[k]\n\n    # Validate the output structure to make sure that it can be unflattened.\n    names = sorted(output)\n    for i in range(len(names) - 1):\n        name, next_name = names[i : i + 2]\n        if next_name.startswith(name + \"/\"):\n            raise ValueError(f\"Leaf '{name}' aliases a node of '{next_name}'\")\n\n    return unflatten_dict(output)\n\n\ndef apply_tree(\n    tree: at.PyTree[T], selector: at.PyTree[S], fn: Callable[[T, S], T], *, strict: bool = False\n) -> at.PyTree[T]:\n    tree = flatten_dict(tree)\n    selector = flatten_dict(selector)\n\n    def transform(k: str, v: T) -> T:\n        if k in selector:\n            return fn(v, selector[k])\n        return v\n\n    if strict:\n        for k in selector:\n            if k not in tree:\n                raise ValueError(f\"Selector key {k} not found in tree\")\n\n    return unflatten_dict({k: transform(k, v) for k, v in tree.items()})\n\n\ndef pad_to_dim(x: np.ndarray, target_dim: int, axis: int = -1, value: float = 0.0) -> np.ndarray:\n    \"\"\"Pad an array to the target 
dimension along the given axis using `value` (zeros by default).\n\n    Arrays that are already at least `target_dim` long along `axis` are returned unchanged.\n    \"\"\"\n    current_dim = x.shape[axis]\n    if current_dim < target_dim:\n        pad_width = [(0, 0)] * len(x.shape)\n        pad_width[axis] = (0, target_dim - current_dim)\n        return np.pad(x, pad_width, constant_values=value)\n    return x\n\n\ndef make_bool_mask(*dims: int) -> tuple[bool, ...]:\n    \"\"\"Make a boolean mask for the given dimensions.\n\n    Example:\n        make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)\n        make_bool_mask(2, 0, 2) == (True, True, True, True)\n\n    Args:\n        dims: Run lengths. A positive value appends that many True entries, a negative value\n            appends that many False entries, and 0 appends nothing.\n\n    Returns:\n        A tuple of booleans.\n    \"\"\"\n    result = []\n    for dim in dims:\n        if dim > 0:\n            result.extend([True] * (dim))\n        else:\n            result.extend([False] * (-dim))\n    return tuple(result)\n\n\ndef _assert_quantile_stats(norm_stats: at.PyTree[NormStats]) -> None:\n    for k, v in flatten_dict(norm_stats).items():\n        if v.q01 is None or v.q99 is None:\n            raise ValueError(\n                f\"quantile stats must be provided if use_quantile_norm is True. Key {k} is missing q01 or q99.\"\n            )\n"
  },
  {
    "path": "src/openpi/transforms_test.py",
    "content": "import numpy as np\nimport pytest\n\nimport openpi.models.tokenizer as _tokenizer\nimport openpi.transforms as _transforms\n\n\ndef test_repack_transform():\n    transform = _transforms.RepackTransform(\n        structure={\n            \"a\": {\"b\": \"b/c\"},\n            \"d\": \"e/f\",\n        }\n    )\n    item = {\"b\": {\"c\": 1}, \"e\": {\"f\": 2}}\n    assert transform(item) == {\"a\": {\"b\": 1}, \"d\": 2}\n\n\ndef test_delta_actions():\n    item = {\"state\": np.array([1, 2, 3]), \"actions\": np.array([[3, 4, 5], [5, 6, 7]])}\n\n    transform = _transforms.DeltaActions(mask=[False, True])\n    transformed = transform(item)\n\n    assert np.all(transformed[\"state\"] == np.array([1, 2, 3]))\n    assert np.all(transformed[\"actions\"] == np.array([[3, 2, 5], [5, 4, 7]]))\n\n\ndef test_delta_actions_noop():\n    item = {\"state\": np.array([1, 2, 3]), \"actions\": np.array([[3, 4, 5], [5, 6, 7]])}\n\n    # No-op when the mask is disabled.\n    transform = _transforms.DeltaActions(mask=None)\n    assert transform(item) is item\n\n    # No-op when there are no actions in the input.\n    del item[\"actions\"]\n    transform = _transforms.DeltaActions(mask=[True, False])\n    assert transform(item) is item\n\n\ndef test_absolute_actions():\n    item = {\"state\": np.array([1, 2, 3]), \"actions\": np.array([[3, 4, 5], [5, 6, 7]])}\n\n    transform = _transforms.AbsoluteActions(mask=[False, True])\n    transformed = transform(item)\n\n    assert np.all(transformed[\"state\"] == np.array([1, 2, 3]))\n    assert np.all(transformed[\"actions\"] == np.array([[3, 6, 5], [5, 8, 7]]))\n\n\ndef test_absolute_actions_noop():\n    item = {\"state\": np.array([1, 2, 3]), \"actions\": np.array([[3, 4, 5], [5, 6, 7]])}\n\n    # No-op when the mask is disabled.\n    transform = _transforms.AbsoluteActions(mask=None)\n    assert transform(item) is item\n\n    # No-op when there are no actions in the input.\n    del item[\"actions\"]\n    transform = 
_transforms.AbsoluteActions(mask=[True, False])\n    assert transform(item) is item\n\n\ndef test_make_bool_mask():\n    assert _transforms.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)\n    assert _transforms.make_bool_mask(2, 0, 2) == (True, True, True, True)\n\n\ndef test_tokenize_prompt():\n    tokenizer = _tokenizer.PaligemmaTokenizer(max_len=12)\n    transform = _transforms.TokenizePrompt(tokenizer)\n\n    data = transform({\"prompt\": \"Hello, world!\"})\n\n    tok_prompt, tok_mask = tokenizer.tokenize(\"Hello, world!\")\n    assert np.allclose(tok_prompt, data[\"tokenized_prompt\"])\n    assert np.allclose(tok_mask, data[\"tokenized_prompt_mask\"])\n\n\ndef test_tokenize_no_prompt():\n    transform = _transforms.TokenizePrompt(_tokenizer.PaligemmaTokenizer())\n\n    with pytest.raises(ValueError, match=\"Prompt is required\"):\n        transform({})\n\n\ndef test_transform_dict():\n    # Rename and remove keys.\n    input = {\"a\": {\"b\": 1, \"c\": 2}}\n    output = _transforms.transform_dict({\"a/b\": \"a/c\", \"a/c\": None}, input)\n    assert output == {\"a\": {\"c\": 1}}\n\n    # Raises an error since the renamed key conflicts with an existing key.\n    with pytest.raises(ValueError, match=\"Key 'a/c' already exists in output\"):\n        _transforms.transform_dict({\"a/b\": \"a/c\"}, input)\n\n    # Full match is required and so nothing will be removed.\n    input = {\"a\": {\"b\": 1, \"c\": 2}}\n    output = _transforms.transform_dict({\"a\": None}, input)\n    assert output == input\n\n    # The regex matches the entire key and so the entire input will be removed.\n    input = {\"a\": {\"b\": 1, \"c\": 2}}\n    output = _transforms.transform_dict({\"a.+\": None}, input)\n    assert output == {}\n\n    # Replace keys using backreferences. 
All leaves named 'c' are replaced with 'd'.\n    input = {\"a\": {\"b\": 1, \"c\": 1}, \"b\": {\"c\": 2}}\n    output = _transforms.transform_dict({\"(.+)/c\": r\"\\1/d\"}, input)\n    assert output == {\"a\": {\"b\": 1, \"d\": 1}, \"b\": {\"d\": 2}}\n\n\ndef test_extract_prompt_from_task():\n    transform = _transforms.PromptFromLeRobotTask({1: \"Hello, world!\"})\n\n    data = transform({\"task_index\": 1})\n    assert data[\"prompt\"] == \"Hello, world!\"\n\n    with pytest.raises(ValueError, match=\"task_index=2 not found in task mapping\"):\n        transform({\"task_index\": 2})\n"
  }
]